Coverage for kwant/linalg/lapack.pyx : 88%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2011-2017 Kwant authors.
2#
3# This file is part of Kwant. It is subject to the license terms in the file
4# LICENSE.rst found in the top-level directory of this distribution and at
5# https://kwant-project.org/license. A list of Kwant authors can be found in
6# the file AUTHORS.rst at the top-level directory of this distribution and at
7# https://kwant-project.org/authors.
"""Low-level access to LAPACK functions.

Thin Cython wrappers around the LAPACK routines exposed by
``scipy.linalg.cython_lapack``. Each wrapper is compiled for the four
LAPACK precisions (float32, float64, complex64, complex128) and
dispatches to the matching s/d/c/z routine. Matrix arguments are
expected to be Fortran-contiguous, and LAPACK works directly on their
buffers, so inputs are generally overwritten.
"""

# Public API of this module.
__all__ = ['getrf',
           'getrs',
           'gecon',
           'ggev',
           'gees',
           'trsen',
           'trevc',
           'gges',
           'tgsen',
           'tgevc',
           'prepare_for_lapack']
23import numpy as np
24cimport numpy as np
26cimport scipy.linalg.cython_lapack as lapack
# C types used for LAPACK INTEGER and LOGICAL arguments.
# NOTE(review): these assume a LAPACK with 32-bit INTEGER/LOGICAL (LP64
# interface), matching the np.int32 dtypes below -- confirm against the
# declarations in scipy.linalg.cython_lapack.
ctypedef int l_int
ctypedef bint l_logical

# NumPy dtypes for arrays whose buffers are handed to LAPACK as l_int /
# l_logical pointers (pivot indices, integer workspaces, select vectors).
int_dtype = np.int32
logical_dtype = np.int32

ctypedef float complex float_complex
ctypedef double complex double_complex

# All four LAPACK precisions (s, d, c, z). Functions taking ``scalar``
# arrays are compiled once per member type.
ctypedef fused scalar:
    float
    double
    float complex
    double complex

# Subsets of ``scalar`` used for compile-time branching such as
# ``if scalar in single_precision:``.
ctypedef fused single_precision:
    float
    float complex

ctypedef fused double_precision:
    double
    double complex

# Complex members of ``scalar``.
ctypedef fused cmplx:
    float complex
    double complex

# Real members of ``scalar``.
ctypedef fused floating:
    float
    double
59# exceptions
class LinAlgError(RuntimeError):
    """Error raised when a LAPACK routine reports a numerical failure.

    Raised by the wrappers in this module when LAPACK returns ``info > 0``,
    e.g. when an iteration fails to converge or a reordering fails.
    """
65# some helper functions
def filter_args(select, args):
    """Return, as a tuple, the entries of *args* whose flag in *select* is true.

    *select* and *args* are paired positionally; surplus entries of the
    longer sequence are ignored.
    """
    kept = []
    for keep, value in zip(select, args):
        if keep:
            kept.append(value)
    return tuple(kept)
def assert_fortran_mat(*mats):
    """Raise ValueError unless every given matrix is Fortran contiguous.

    ``None`` arguments are skipped. Matrices with at most one row and at
    most one column are skipped as well: this works around a bug in NumPy
    versions < 2.0, where 1x1 matrices do not have the F_CONTIGUOUS flag
    set correctly.
    """
    for candidate in mats:
        if candidate is None:
            continue
        # Same short-circuit order as "shape[0] > 1 or shape[1] > 1".
        if candidate.shape[0] <= 1 and candidate.shape[1] <= 1:
            continue
        if not candidate.flags["F_CONTIGUOUS"]:
            raise ValueError("Input matrix must be Fortran contiguous")
cdef np.ndarray maybe_complex(scalar selector,
                              np.ndarray real, np.ndarray imag):
    """Merge separate real/imaginary eigenvalue arrays if necessary.

    ``selector`` only selects the fused-type specialization. For the real
    specializations, ``real + 1j * imag`` is returned when ``imag`` has a
    nonzero entry. In every other case ``real`` itself is returned, and
    ``imag`` is not inspected for the complex specializations.
    """
    if scalar in floating:
        if imag.any():
            return real + 1j * imag
    return real
cdef l_int lwork_from_qwork(scalar qwork):
    """Turn the workspace size reported by a LAPACK workspace query into
    an ``l_int``.

    LAPACK stores the optimal LWORK in the first WORK entry. That entry
    is a ``scalar``, so complex specializations carry it in the real part.
    """
    # Keep both arms of the if/else: the non-matching arm is pruned at
    # compile time for each fused specialization.
    if scalar in cmplx:
        return <l_int>qwork.real
    else:
        return <l_int>qwork
def getrf(np.ndarray[scalar, ndim=2] A):
    """Compute the LU factorization of ``A`` in place (LAPACK xGETRF).

    Parameters
    ----------
    A : Fortran-contiguous 2-D array (float32, float64, complex64 or
        complex128). LAPACK overwrites it with the L and U factors.

    Returns
    -------
    A : the same array, now holding the LU factors.
    ipiv : int32 array of length min(M, N) with LAPACK's pivot indices.
    singular : True if LAPACK reports an exactly zero pivot (info > 0) or
        if ``A`` is not square.
    """
    cdef l_int M, N, lda, info
    cdef np.ndarray[l_int] ipiv

    assert_fortran_mat(A)

    M = A.shape[0]
    N = A.shape[1]
    # LAPACK requires LDA >= max(1, M) even for an empty matrix; passing
    # LDA = 0 would make it report an argument error.
    lda = max(1, M)
    ipiv = np.empty(min(M, N), dtype=int_dtype)

    if scalar is float:
        lapack.sgetrf(&M, &N, <float *>A.data, &lda,
                      <l_int *>ipiv.data, &info)
    elif scalar is double:
        lapack.dgetrf(&M, &N, <double *>A.data, &lda,
                      <l_int *>ipiv.data, &info)
    elif scalar is float_complex:
        lapack.cgetrf(&M, &N, <float complex *>A.data, &lda,
                      <l_int *>ipiv.data, &info)
    elif scalar is double_complex:
        lapack.zgetrf(&M, &N, <double complex *>A.data, &lda,
                      <l_int *>ipiv.data, &info)

    assert info >= 0, "Argument error in getrf"

    return (A, ipiv, info > 0 or M != N)
def getrs(np.ndarray[scalar, ndim=2] LU, np.ndarray[l_int] IPIV,
          np.ndarray B):
    """Solve ``A X = B`` using the LU factorization from `getrf`
    (LAPACK xGETRS).

    Parameters
    ----------
    LU : square Fortran-contiguous matrix holding the LU factors.
    IPIV : int32 pivot indices from `getrf`, at least ``LU.shape[0]`` long.
    B : Fortran-contiguous vector or matrix with the same dtype as ``LU``
        and ``LU.shape[0]`` rows. LAPACK overwrites it with the solution.

    Returns
    -------
    B : the same array, now holding the solution ``X``.

    Raises
    ------
    TypeError : if ``B`` does not have the same dtype as ``LU``.
    ValueError : if shapes or memory layouts are incompatible.
    """
    cdef l_int N, NRHS, ld, info

    assert_fortran_mat(LU)

    # Consistency checks for LU, IPIV and B

    if LU.shape[0] != LU.shape[1]:
        raise ValueError("LU must be a square matrix")

    if B.descr.type_num != LU.descr.type_num:
        raise TypeError('B must have same dtype as LU')

    if B.ndim < 1 or B.ndim > 2:
        raise ValueError("B must be a vector or matrix")

    # LAPACK treats B.data as one contiguous column-major block, so
    # strided vectors must be rejected as well. Single-element arrays are
    # exempt (workaround for the 1x1-Fortran bug in NumPy < v2.0).
    if B.ndim == 1:
        if B.shape[0] > 1 and not B.flags["F_CONTIGUOUS"]:
            raise ValueError("B must be Fortran ordered")
    elif ((B.shape[0] > 1 or B.shape[1] > 1) and
          not B.flags["F_CONTIGUOUS"]):
        raise ValueError("B must be Fortran ordered")

    if LU.shape[0] != B.shape[0]:
        raise ValueError('LU and B have incompatible shapes')

    N = LU.shape[0]

    # LAPACK reads N pivot indices.
    if IPIV is None or IPIV.shape[0] < N:
        raise ValueError("IPIV must contain at least N pivot indices")

    if B.ndim == 1:
        NRHS = 1
    else:
        NRHS = B.shape[1]

    # LAPACK requires LDA, LDB >= max(1, N), even for an empty system.
    ld = max(1, N)

    if scalar is float:
        lapack.sgetrs("N", &N, &NRHS, <float *>LU.data, &ld,
                      <l_int *>IPIV.data, <float *>B.data, &ld,
                      &info)
    elif scalar is double:
        lapack.dgetrs("N", &N, &NRHS, <double *>LU.data, &ld,
                      <l_int *>IPIV.data, <double *>B.data, &ld,
                      &info)
    elif scalar is float_complex:
        lapack.cgetrs("N", &N, &NRHS, <float complex *>LU.data, &ld,
                      <l_int *>IPIV.data, <float complex *>B.data, &ld,
                      &info)
    elif scalar is double_complex:
        lapack.zgetrs("N", &N, &NRHS, <double complex *>LU.data, &ld,
                      <l_int *>IPIV.data, <double complex *>B.data, &ld,
                      &info)

    assert info == 0, "Argument error in getrs"

    return B
def gecon(np.ndarray[scalar, ndim=2] LU, double normA, char *norm = b"1"):
    """Estimate the reciprocal condition number of a matrix from its LU
    factorization (LAPACK xGECON).

    Parameters
    ----------
    LU : square Fortran-contiguous matrix holding the LU factors of ``A``,
        as returned by `getrf`.
    normA : the norm of the original matrix ``A`` selected by ``norm``.
    norm : b"1" for the 1-norm or b"I" for the infinity-norm.

    Returns
    -------
    rcond : the estimated reciprocal condition number, computed in the
        precision of ``LU``.

    Raises
    ------
    ValueError : if ``norm`` is invalid or ``LU`` is not square.
    """
    cdef l_int N, lda, info
    cdef float srcond, snormA
    cdef double drcond

    # Parameter checks

    assert_fortran_mat(LU)
    if norm[0] != b"1" and norm[0] != b"I":
        raise ValueError("'norm' must be either '1' or 'I'")
    if LU.shape[0] != LU.shape[1]:
        raise ValueError("LU must be a square matrix")
    if scalar in single_precision:
        snormA = normA

    # Allocate workspaces

    N = LU.shape[0]
    # LAPACK requires LDA >= max(1, N), even for an empty matrix.
    lda = max(1, N)

    # Integer workspace is only used by the real routines.
    cdef np.ndarray[l_int] iwork
    if scalar in floating:
        iwork = np.empty(N, dtype=int_dtype)

    cdef np.ndarray[scalar] work
    if scalar in floating:
        work = np.empty(4 * N, dtype=LU.dtype)
    else:
        work = np.empty(2 * N, dtype=LU.dtype)

    # Real workspace is only used by the complex routines.
    cdef np.ndarray rwork
    if scalar is float_complex:
        rwork = np.empty(2 * N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(2 * N, dtype=np.float64)

    # The actual calculation

    if scalar is float:
        lapack.sgecon(norm, &N, <float *>LU.data, &lda, &snormA,
                      &srcond, <float *>work.data,
                      <l_int *>iwork.data, &info)
    elif scalar is double:
        lapack.dgecon(norm, &N, <double *>LU.data, &lda, &normA,
                      &drcond, <double *>work.data,
                      <l_int *>iwork.data, &info)
    elif scalar is float_complex:
        lapack.cgecon(norm, &N, <float complex *>LU.data, &lda, &snormA,
                      &srcond, <float complex *>work.data,
                      <float *>rwork.data, &info)
    elif scalar is double_complex:
        lapack.zgecon(norm, &N, <double complex *>LU.data, &lda, &normA,
                      &drcond, <double complex *>work.data,
                      <double *>rwork.data, &info)

    assert info == 0, "Argument error in gecon"

    if scalar in single_precision:
        return srcond
    else:
        return drcond
# Helper function for xGGEV
def ggev_postprocess(dtype, alphar, alphai, vl_r=None, vr_r=None):
    """Convert the real-arithmetic output of xGGEV into complex form.

    Real LAPACK drivers return eigenvalue numerators as separate real and
    imaginary parts. They store a complex-conjugate eigenvector pair in
    two consecutive real columns: column j holds the real part, column
    j+1 the imaginary part, and column j belongs to the eigenvalue with
    positive imaginary part.

    If no entry of ``alphai`` is positive, ``(alphar, vl_r, vr_r)`` is
    returned unchanged. Otherwise the result is ``alphar + 1j * alphai``
    together with copies of the given eigenvector matrices, converted to
    ``dtype`` and with both members of every conjugate pair rebuilt.
    Eigenvector matrices passed as ``None`` stay ``None``.
    """
    pair_starts = (alphai > 0.0).nonzero()[0]

    if not pair_starts.size:
        return (alphar, vl_r, vr_r)

    def rebuild(v_real):
        # Reconstruct complex eigenvectors from the packed real columns.
        if v_real is None:
            return None
        v_cplx = np.array(v_real, dtype=dtype)
        for j in pair_starts:
            v_cplx.imag[:, j] = v_real[:, j + 1]
            v_cplx[:, j + 1] = np.conj(v_cplx[:, j])
        return v_cplx

    alpha = alphar + 1j * alphai
    return (alpha, rebuild(vl_r), rebuild(vr_r))
def ggev(np.ndarray[scalar, ndim=2] A, np.ndarray[scalar, ndim=2] B,
         left=False, right=True):
    """Solve the generalized eigenvalue problem for the pencil ``(A, B)``
    (LAPACK xGGEV).

    Eigenvalues are returned as ratios ``alpha / beta``. LAPACK works on
    the buffers of ``A`` and ``B`` directly, so both are overwritten.

    Parameters
    ----------
    A, B : square Fortran-contiguous matrices of the same shape and dtype.
    left, right : whether to compute left and/or right eigenvectors.

    Returns
    -------
    A tuple ``(alpha, beta[, vl][, vr])``: ``vl`` is present only if
    ``left`` and ``vr`` only if ``right``. For real input, ``alpha`` and
    the eigenvectors become complex if any eigenvalue has a positive
    imaginary part (see `ggev_postprocess`).

    Raises
    ------
    ValueError : if the input shapes are invalid.
    LinAlgError : if LAPACK reports a failure (info > 0).
    """
    cdef l_int N, info, lwork

    # Parameter checks

    assert_fortran_mat(A, B)

    # Both operands must be matrices (previously only A was checked).
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("gen_eig requires both a and b to be matrices")

    if A.shape[0] != A.shape[1]:
        raise ValueError("gen_eig requires square matrix input")

    if A.shape[0] != B.shape[0] or A.shape[1] != B.shape[1]:
        raise ValueError("A and B do not have the same shape")

    # Allocate workspaces

    N = A.shape[0]

    # Complex drivers return alpha as one complex array; real drivers
    # split it into real and imaginary parts.
    cdef np.ndarray[scalar] alphar, alphai
    if scalar in cmplx:
        alphar = np.empty(N, dtype=A.dtype)
        alphai = None
    else:
        alphar = np.empty(N, dtype=A.dtype)
        alphai = np.empty(N, dtype=A.dtype)

    cdef np.ndarray[scalar] beta = np.empty(N, dtype=A.dtype)

    # Real workspace is only used by the complex routines.
    cdef np.ndarray rwork = None
    if scalar is float_complex:
        rwork = np.empty(8 * N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(8 * N, dtype=np.float64)

    cdef np.ndarray vl
    cdef scalar *vl_ptr
    cdef char *jobvl
    if left:
        vl = np.empty((N,N), dtype=A.dtype, order='F')
        vl_ptr = <scalar *>vl.data
        jobvl = "V"
    else:
        vl = None
        vl_ptr = NULL
        jobvl = "N"

    cdef np.ndarray vr
    cdef scalar *vr_ptr
    cdef char *jobvr
    if right:
        vr = np.empty((N,N), dtype=A.dtype, order='F')
        vr_ptr = <scalar *>vr.data
        jobvr = "V"
    else:
        vr = None
        vr_ptr = NULL
        jobvr = "N"

    # Workspace query (LWORK = -1): the optimal size is returned in qwork.
    # Xggev expects &qwork as a <scalar *> (even though it's an integer)
    lwork = -1
    cdef scalar qwork

    if scalar is float:
        lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
                     <float *>B.data, &N,
                     <float *>alphar.data, <float *> alphai.data,
                     <float *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     &qwork, &lwork, &info)
    elif scalar is double:
        lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
                     <double *>B.data, &N,
                     <double *>alphar.data, <double *> alphai.data,
                     <double *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     &qwork, &lwork, &info)
    elif scalar is float_complex:
        lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
                     <float complex *>B.data, &N,
                     <float complex *>alphar.data, <float complex *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     &qwork, &lwork,
                     <float *>rwork.data, &info)
    elif scalar is double_complex:
        lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
                     <double complex *>B.data, &N,
                     <double complex *>alphar.data, <double complex *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     &qwork, &lwork,
                     <double *>rwork.data, &info)

    assert info == 0, "Argument error in ggev"

    lwork = lwork_from_qwork(qwork)
    cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype)

    # The actual calculation

    if scalar is float:
        lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
                     <float *>B.data, &N,
                     <float *>alphar.data, <float *> alphai.data,
                     <float *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     <float *>work.data, &lwork, &info)
    elif scalar is double:
        lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
                     <double *>B.data, &N,
                     <double *>alphar.data, <double *> alphai.data,
                     <double *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     <double *>work.data, &lwork, &info)
    elif scalar is float_complex:
        lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
                     <float complex *>B.data, &N,
                     <float complex *>alphar.data, <float complex *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     <float complex *>work.data, &lwork,
                     <float *>rwork.data, &info)
    elif scalar is double_complex:
        lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
                     <double complex *>B.data, &N,
                     <double complex *>alphar.data, <double complex *>beta.data,
                     vl_ptr, &N, vr_ptr, &N,
                     <double complex *>work.data, &lwork,
                     <double *>rwork.data, &info)

    if info > 0:
        raise LinAlgError("QZ iteration failed to converge in ggev")

    assert info == 0, "Argument error in ggev"

    # Complex dtype used when real output has to be made complex.
    if scalar is float:
        post_dtype = np.complex64
    elif scalar is double:
        post_dtype = np.complex128

    cdef np.ndarray alpha
    alpha = alphar
    if scalar in floating:
        alpha, vl, vr = ggev_postprocess(post_dtype, alphar, alphai, vl, vr)

    return filter_args((True, True, left, right), (alpha, beta, vl, vr))
def gees(np.ndarray[scalar, ndim=2] A, calc_q=True, calc_ev=True):
    """Compute the Schur decomposition of ``A`` (LAPACK xGEES).

    LAPACK works on ``A``'s buffer directly and overwrites it with the
    Schur form. No eigenvalue ordering is requested (SORT = "N").

    Parameters
    ----------
    A : square Fortran-contiguous matrix.
    calc_q : whether to compute and return the Schur vectors.
    calc_ev : whether to return the eigenvalues.

    Returns
    -------
    A tuple ``(T[, Q][, w])``. For real input, ``w`` is complex when any
    eigenvalue has a nonzero imaginary part.

    Raises
    ------
    ValueError : if ``A`` is not a square matrix.
    LinAlgError : if LAPACK reports a failure (info > 0).
    """
    cdef l_int N, lwork, sdim, info

    assert_fortran_mat(A)

    if A.ndim != 2:
        raise ValueError("Expect matrix as input")

    if A.shape[0] != A.shape[1]:
        raise ValueError("Expect square matrix")

    # Allocate workspaces

    N = A.shape[0]

    # Complex drivers return eigenvalues in one complex array; real
    # drivers split them into real and imaginary parts.
    cdef np.ndarray[scalar] wr, wi
    if scalar in cmplx:
        wr = np.empty(N, dtype=A.dtype)
        wi = None
    else:
        wr = np.empty(N, dtype=A.dtype)
        wi = np.empty(N, dtype=A.dtype)

    # Real workspace is only used by the complex routines.
    cdef np.ndarray rwork
    if scalar is float_complex:
        rwork = np.empty(N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(N, dtype=np.float64)

    cdef char *jobvs
    cdef scalar *vs_ptr
    cdef np.ndarray[scalar, ndim=2] vs
    if calc_q:
        vs = np.empty((N,N), dtype=A.dtype, order='F')
        vs_ptr = <scalar *>vs.data
        jobvs = "V"
    else:
        vs = None
        vs_ptr = NULL
        jobvs = "N"

    # Workspace query (LWORK = -1): the optimal size is returned in qwork.
    # Xgees expects &qwork as a <scalar *> (even though it's an integer)
    lwork = -1
    cdef scalar qwork

    if scalar is float:
        lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N,
                     &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N,
                     &qwork, &lwork, NULL, &info)
    elif scalar is double:
        lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N,
                     &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N,
                     &qwork, &lwork, NULL, &info)
    elif scalar is float_complex:
        lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N,
                     &sdim, <float complex *>wr.data, vs_ptr, &N,
                     &qwork, &lwork, <float *>rwork.data, NULL, &info)
    elif scalar is double_complex:
        lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N,
                     &sdim, <double complex *>wr.data, vs_ptr, &N,
                     &qwork, &lwork, <double *>rwork.data, NULL, &info)

    # Same message as the check after the actual calculation (this wrapper
    # serves all four precisions, not only sgees).
    assert info == 0, "Argument error in gees"

    lwork = lwork_from_qwork(qwork)
    cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype)

    # The actual calculation

    if scalar is float:
        lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N,
                     &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N,
                     <float *>work.data, &lwork, NULL, &info)
    elif scalar is double:
        lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N,
                     &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N,
                     <double *>work.data, &lwork, NULL, &info)
    elif scalar is float_complex:
        lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N,
                     &sdim, <float complex *>wr.data, vs_ptr, &N,
                     <float complex *>work.data, &lwork,
                     <float *>rwork.data, NULL, &info)
    elif scalar is double_complex:
        lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N,
                     &sdim, <double complex *>wr.data, vs_ptr, &N,
                     <double complex *>work.data, &lwork,
                     <double *>rwork.data, NULL, &info)

    if info > 0:
        raise LinAlgError("QR iteration failed to converge in gees")

    assert info == 0, "Argument error in gees"

    # Real inputs possibly produce complex output
    cdef np.ndarray w = maybe_complex[scalar](0, wr, wi)

    return filter_args((True, calc_q, calc_ev), (A, vs, w))
def trsen(np.ndarray[l_logical] select,
          np.ndarray[scalar, ndim=2] T,
          np.ndarray[scalar, ndim=2] Q,
          calc_ev=True):
    """Reorder a Schur factorization (LAPACK xTRSEN).

    The eigenvalues flagged in ``select`` are moved to the leading block
    of ``T``. LAPACK updates ``T`` (and ``Q``, if given) in place. No
    condition numbers are computed (JOB = "N").

    Parameters
    ----------
    select : int32 vector with one flag per eigenvalue (length N).
    T : square Fortran-contiguous Schur form.
    Q : Schur vectors with the same shape as ``T``, or None.
    calc_ev : whether to return the reordered eigenvalues.

    Returns
    -------
    A tuple ``(T[, Q][, w])``. ``Q`` is included only if it was given.
    For real input, ``w`` is complex when any eigenvalue has a nonzero
    imaginary part.

    Raises
    ------
    ValueError : if the input shapes are inconsistent.
    LinAlgError : if the reordering fails (ill-conditioned problem).
    """
    cdef l_int N, M, lwork, liwork, qiwork, info

    assert_fortran_mat(T, Q)

    # Parameter checks. LAPACK reads N entries of select and writes an
    # N x N block into Q, so inconsistent inputs would access memory out
    # of bounds instead of failing cleanly.

    N = T.shape[0]

    if T.shape[1] != N:
        raise ValueError("Expect square matrix T")
    if Q is not None and (Q.shape[0] != N or Q.shape[1] != N):
        raise ValueError("Q must have the same shape as T")
    if select is None or select.shape[0] != N:
        raise ValueError("select must have one entry per row of T")
    # LAPACK reads select.data as a contiguous array; this is a no-op for
    # an already contiguous select.
    select = np.ascontiguousarray(select, dtype=logical_dtype)

    # Allocate workspaces

    # Complex drivers return eigenvalues in one complex array; real
    # drivers split them into real and imaginary parts.
    cdef np.ndarray[scalar] wr, wi
    if scalar in cmplx:
        wr = np.empty(N, dtype=T.dtype)
        wi = None
    else:
        wr = np.empty(N, dtype=T.dtype)
        wi = np.empty(N, dtype=T.dtype)

    cdef char *compq
    cdef scalar *q_ptr
    if Q is not None:
        compq = "V"
        q_ptr = <scalar *>Q.data
    else:
        compq = "N"
        q_ptr = NULL

    # Workspace query (LWORK = LIWORK = -1).
    # Xtrsen expects &qwork as a <scalar *> (even though it's an integer)
    cdef scalar qwork
    lwork = liwork = -1

    if scalar is float:
        lapack.strsen("N", compq, <l_logical *>select.data,
                      &N, <float *>T.data, &N, q_ptr, &N,
                      <float *>wr.data, <float *>wi.data, &M, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)
    elif scalar is double:
        lapack.dtrsen("N", compq, <l_logical *>select.data,
                      &N, <double *>T.data, &N, q_ptr, &N,
                      <double *>wr.data, <double *>wi.data, &M, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)
    elif scalar is float_complex:
        lapack.ctrsen("N", compq, <l_logical *>select.data,
                      &N, <float complex *>T.data, &N, q_ptr, &N,
                      <float complex *>wr.data, &M, NULL, NULL,
                      &qwork, &lwork, &info)
    elif scalar is double_complex:
        lapack.ztrsen("N", compq, <l_logical *>select.data,
                      &N, <double complex *>T.data, &N, q_ptr, &N,
                      <double complex *>wr.data, &M, NULL, NULL,
                      &qwork, &lwork, &info)

    assert info == 0, "Argument error in trsen"

    lwork = lwork_from_qwork(qwork)
    cdef np.ndarray[scalar] work = np.empty(lwork, dtype=T.dtype)

    # Only the real routines take an integer workspace.
    cdef np.ndarray[l_int] iwork = None
    if scalar in floating:
        liwork = qiwork
        iwork = np.empty(liwork, dtype=int_dtype)

    # The actual calculation

    if scalar is float:
        lapack.strsen("N", compq, <l_logical *>select.data,
                      &N, <float *>T.data, &N, q_ptr, &N,
                      <float *>wr.data, <float *>wi.data, &M, NULL, NULL,
                      <float *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)
    elif scalar is double:
        lapack.dtrsen("N", compq, <l_logical *>select.data,
                      &N, <double *>T.data, &N, q_ptr, &N,
                      <double *>wr.data, <double *>wi.data, &M, NULL, NULL,
                      <double *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)
    elif scalar is float_complex:
        lapack.ctrsen("N", compq, <l_logical *>select.data,
                      &N, <float complex *>T.data, &N, q_ptr, &N,
                      <float complex *>wr.data, &M, NULL, NULL,
                      <float complex *>work.data, &lwork, &info)
    elif scalar is double_complex:
        lapack.ztrsen("N", compq, <l_logical *>select.data,
                      &N, <double complex *>T.data, &N, q_ptr, &N,
                      <double complex *>wr.data, &M, NULL, NULL,
                      <double complex *>work.data, &lwork, &info)

    if info > 0:
        raise LinAlgError("Reordering failed; problem is very ill-conditioned")

    assert info == 0, "Argument error in trsen"

    # Real inputs possibly produce complex output
    cdef np.ndarray w = maybe_complex[scalar](0, wr, wi)

    return filter_args((True, Q is not None, calc_ev), (T, Q, w))
# Helper function for xTREVC and xTGEVC
def txevc_postprocess(dtype, T, vreal, np.ndarray[l_logical] select):
    """Convert eigenvectors computed in real arithmetic to complex form.

    Parameters
    ----------
    dtype : dtype of the returned array (the callers pass a complex dtype).
    T : the quasi-triangular matrix the eigenvectors belong to. A nonzero
        subdiagonal entry ``T[k+1, k]`` marks a 2x2 block holding a
        complex-conjugate eigenvalue pair.
    vreal : real eigenvector matrix produced by LAPACK. A complex pair
        occupies two consecutive columns, read here as (real part,
        imaginary part).
    select : vector flagging the selected eigenvalues, or None to treat
        all N eigenvalues as selected.

    Returns
    -------
    v : Fortran-ordered array of shape (N, M), where M is the number of
        selected eigenvalues, holding one eigenvector per column.
    """
    cdef int N, M, i, m, indx

    N = T.shape[0]
    if select is None:
        select = np.ones(N, dtype = logical_dtype)
    selindx = select.nonzero()[0]
    M = selindx.size

    v = np.empty((N, M), dtype = dtype, order='F')

    # indx is the next unread column of vreal.
    indx = 0
    for m in range(M):
        k = selindx[m]

        if k < N-1 and T[k+1,k]:
            # we have the situation of a 2x2 block, and
            # the eigenvalue with the positive imaginary part desired
            v[:, m] = vreal[:, indx] + 1j * vreal[:, indx + 1]

            # Check whether the partner eigenvalue (the one with negative
            # imaginary part, index k+1) is also selected. If it is, its
            # eigenvector is built from the same two columns in the next
            # iteration, so indx is not advanced yet.
            if not select[k+1]:
                indx += 2
        elif k > 0 and T[k,k-1]:
            # we have the situation of a 2x2 block, and
            # the eigenvalue with the negative imaginary part desired
            v[:, m] = vreal[:, indx] - 1j * vreal[:, indx + 1]

            indx += 2
        else:
            # real eigenvalue
            v[:, m] = vreal[:, indx]

            indx += 1
    return v
def trevc(np.ndarray[scalar, ndim=2] T,
          np.ndarray[scalar, ndim=2] Q,
          np.ndarray[l_logical] select,
          left=False, right=True):
    """Compute eigenvectors of a Schur form (LAPACK xTREVC).

    Parameters
    ----------
    T : square Fortran-contiguous Schur form.
    Q : Schur vectors with the same shape as ``T``, or None. If given,
        the eigenvectors are transformed back with ``Q``.
    select : int32 vector with one flag per eigenvalue, or None to compute
        the eigenvectors of all eigenvalues.
    left, right : whether to compute left and/or right eigenvectors.

    Returns
    -------
    ``(vl, vr)`` if both ``left`` and ``right`` are set, otherwise ``vl``
    or ``vr``. Returns None if neither is requested. For real input with
    2x2 blocks in ``T``, the eigenvectors are returned as complex arrays
    (see `txevc_postprocess`).

    Raises
    ------
    ValueError : if the input shapes are inconsistent.
    """
    cdef l_int N, info, M, MM
    cdef char *side
    cdef char *howmny

    # Parameter checks. Q may be None; its shape must not be accessed in
    # that case.

    if (T.shape[0] != T.shape[1] or
            (Q is not None and (Q.shape[0] != Q.shape[1] or
                                T.shape[0] != Q.shape[0]))):
        raise ValueError("Invalid Schur decomposition as input")

    # LAPACK reads N entries of select.
    if select is not None and select.shape[0] != T.shape[0]:
        raise ValueError("select must have one entry per row of T")

    assert_fortran_mat(T, Q)

    # Workspace allocation

    N = T.shape[0]

    cdef np.ndarray[scalar] work
    if scalar in floating:
        work = np.empty(4 * N, dtype=T.dtype)
    else:
        work = np.empty(2 * N, dtype=T.dtype)

    # Real workspace is only used by the complex routines.
    cdef np.ndarray rwork = None
    if scalar is float_complex:
        rwork = np.empty(N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(N, dtype=np.float64)

    if left and right:
        side = "B"
    elif left:
        side = "L"
    elif right:
        side = "R"
    else:
        return

    cdef np.ndarray[l_logical] select_cpy
    cdef l_logical *select_ptr
    if select is not None:
        howmny = "S"
        MM = select.nonzero()[0].size
        # Correct for possible additional storage if a single complex
        # eigenvalue is selected.
        # For that: Figure out the positions of the 2x2 blocks.
        cmplxindx = np.diagonal(T, -1).nonzero()[0]
        for i in cmplxindx:
            if bool(select[i]) != bool(select[i+1]):
                MM += 1

        # Select is overwritten in strevc.
        select_cpy = np.array(select, dtype = logical_dtype,
                              order = 'F')
        select_ptr = <l_logical *>select_cpy.data
    else:
        MM = N
        select_ptr = NULL
        if Q is not None:
            # Back-transform: LAPACK multiplies the eigenvectors by the
            # matrix passed in VL/VR, initialized below with a copy of Q.
            howmny = "B"
        else:
            howmny = "A"

    cdef np.ndarray[scalar, ndim=2] vl_r = None
    cdef scalar *vl_r_ptr
    if left:
        if Q is not None and select is None:
            vl_r = np.asfortranarray(Q.copy())
        else:
            vl_r = np.empty((N, MM), dtype=T.dtype, order='F')
        vl_r_ptr = <scalar *>vl_r.data
    else:
        vl_r_ptr = NULL

    cdef np.ndarray[scalar, ndim=2] vr_r = None
    cdef scalar *vr_r_ptr
    if right:
        if Q is not None and select is None:
            vr_r = np.asfortranarray(Q.copy())
        else:
            vr_r = np.empty((N, MM), dtype=T.dtype, order='F')
        vr_r_ptr = <scalar *>vr_r.data
    else:
        vr_r_ptr = NULL

    # The actual calculation

    if scalar is float:
        lapack.strevc(side, howmny, select_ptr,
                      &N, <float *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <float *>work.data, &info)
    elif scalar is double:
        lapack.dtrevc(side, howmny, select_ptr,
                      &N, <double *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <double *>work.data, &info)
    elif scalar is float_complex:
        lapack.ctrevc(side, howmny, select_ptr,
                      &N, <float complex *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <float complex *>work.data, <float *>rwork.data, &info)
    elif scalar is double_complex:
        lapack.ztrevc(side, howmny, select_ptr,
                      &N, <double complex *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <double complex *>work.data, <double *>rwork.data, &info)

    assert info == 0, "Argument error in trevc"
    assert MM == M, "Unexpected number of eigenvectors returned in strevc"

    # With a selection LAPACK cannot back-transform, so apply Q here.
    if select is not None and Q is not None:
        if left:
            vl_r = np.asfortranarray(np.dot(Q, vl_r))
        if right:
            vr_r = np.asfortranarray(np.dot(Q, vr_r))

    cdef np.ndarray vl, vr
    if left:
        vl = vl_r
    if right:
        vr = vr_r
    if scalar in floating:
        # If there are complex eigenvalues, we need to postprocess the
        # eigenvectors.
        if scalar is float:
            dtype = np.complex64
        else:
            dtype = np.complex128
        if np.diagonal(T, -1).nonzero()[0].size:
            if left:
                vl = txevc_postprocess(dtype, T, vl_r, select)
            if right:
                vr = txevc_postprocess(dtype, T, vr_r, select)

    if left and right:
        return (vl, vr)
    elif left:
        return vl
    else:
        return vr
def gges(np.ndarray[scalar, ndim=2] A,
         np.ndarray[scalar, ndim=2] B,
         calc_q=True, calc_z=True, calc_ev=True):
    """Compute the generalized Schur (QZ) decomposition of ``(A, B)``
    (LAPACK xGGES).

    LAPACK works on the buffers of ``A`` and ``B`` directly and overwrites
    them with the generalized Schur forms. No eigenvalue ordering is
    requested (SORT = "N").

    Parameters
    ----------
    A, B : square Fortran-contiguous matrices of the same shape and dtype.
    calc_q, calc_z : whether to compute the left/right Schur vectors.
    calc_ev : whether to return the generalized eigenvalues as
        ``(alpha, beta)``.

    Returns
    -------
    A tuple ``(S, T[, Q][, Z][, alpha, beta])``. For real input, ``alpha``
    is complex when any eigenvalue has a nonzero imaginary part.

    Raises
    ------
    ValueError : if ``A`` is not square or ``B`` has a different shape.
    LinAlgError : if LAPACK reports a failure (info > 0).
    """
    cdef l_int N, sdim, info

    # Check parameters

    assert_fortran_mat(A, B)

    # A itself must be square (this previously compared A.shape[0] with
    # B.shape[1], which let non-square A through).
    if A.shape[0] != A.shape[1]:
        raise ValueError("Expect square matrix A")

    if A.shape[0] != B.shape[0] or A.shape[0] != B.shape[1]:
        raise ValueError("Shape of B is incompatible with matrix A")

    # Allocate workspaces

    N = A.shape[0]

    # Complex drivers return alpha as one complex array; real drivers
    # split it into real and imaginary parts.
    cdef np.ndarray[scalar] alphar, alphai
    if scalar in cmplx:
        alphar = np.empty(N, dtype=A.dtype)
        alphai = None
    else:
        alphar = np.empty(N, dtype=A.dtype)
        alphai = np.empty(N, dtype=A.dtype)

    cdef np.ndarray[scalar] beta = np.empty(N, dtype=A.dtype)

    # Real workspace is only used by the complex routines.
    cdef np.ndarray rwork = None
    if scalar is float_complex:
        rwork = np.empty(8 * N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(8 * N, dtype=np.float64)

    cdef char *jobvsl
    cdef scalar *vsl_ptr
    cdef np.ndarray[scalar, ndim=2] vsl
    if calc_q:
        vsl = np.empty((N,N), dtype=A.dtype, order='F')
        vsl_ptr = <scalar *>vsl.data
        jobvsl = "V"
    else:
        vsl = None
        vsl_ptr = NULL
        jobvsl = "N"

    cdef char *jobvsr
    cdef scalar *vsr_ptr
    cdef np.ndarray[scalar, ndim=2] vsr
    if calc_z:
        vsr = np.empty((N,N), dtype=A.dtype, order='F')
        vsr_ptr = <scalar *>vsr.data
        jobvsr = "V"
    else:
        vsr = None
        vsr_ptr = NULL
        jobvsr = "N"

    # Workspace query (LWORK = -1): the optimal size is returned in qwork.
    # Xgges expects &qwork as a <scalar *> (even though it's an integer)
    cdef l_int lwork = -1
    cdef scalar qwork

    if scalar is float:
        lapack.sgges(jobvsl, jobvsr, "N", NULL,
                     &N, <float *>A.data, &N,
                     <float *>B.data, &N, &sdim,
                     <float *>alphar.data, <float *>alphai.data,
                     <float *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     &qwork, &lwork, NULL, &info)
    elif scalar is double:
        lapack.dgges(jobvsl, jobvsr, "N", NULL,
                     &N, <double *>A.data, &N,
                     <double *>B.data, &N, &sdim,
                     <double *>alphar.data, <double *>alphai.data,
                     <double *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     &qwork, &lwork, NULL, &info)
    elif scalar is float_complex:
        lapack.cgges(jobvsl, jobvsr, "N", NULL,
                     &N, <float complex *>A.data, &N,
                     <float complex *>B.data, &N, &sdim,
                     <float complex *>alphar.data, <float complex *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     &qwork, &lwork, <float *>rwork.data, NULL, &info)
    elif scalar is double_complex:
        lapack.zgges(jobvsl, jobvsr, "N", NULL,
                     &N, <double complex *>A.data, &N,
                     <double complex *>B.data, &N, &sdim,
                     <double complex *>alphar.data, <double complex *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     &qwork, &lwork, <double *>rwork.data, NULL, &info)

    assert info == 0, "Argument error in gges"

    lwork = lwork_from_qwork(qwork)
    cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype)

    # The actual calculation

    if scalar is float:
        lapack.sgges(jobvsl, jobvsr, "N", NULL,
                     &N, <float *>A.data, &N,
                     <float *>B.data, &N, &sdim,
                     <float *>alphar.data, <float *>alphai.data,
                     <float *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     <float *>work.data, &lwork, NULL, &info)
    elif scalar is double:
        lapack.dgges(jobvsl, jobvsr, "N", NULL,
                     &N, <double *>A.data, &N,
                     <double *>B.data, &N, &sdim,
                     <double *>alphar.data, <double *>alphai.data,
                     <double *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     <double *>work.data, &lwork, NULL, &info)
    elif scalar is float_complex:
        lapack.cgges(jobvsl, jobvsr, "N", NULL,
                     &N, <float complex *>A.data, &N,
                     <float complex *>B.data, &N, &sdim,
                     <float complex *>alphar.data, <float complex *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     <float complex *>work.data, &lwork,
                     <float *>rwork.data, NULL, &info)
    elif scalar is double_complex:
        lapack.zgges(jobvsl, jobvsr, "N", NULL,
                     &N, <double complex *>A.data, &N,
                     <double complex *>B.data, &N, &sdim,
                     <double complex *>alphar.data, <double complex *>beta.data,
                     vsl_ptr, &N, vsr_ptr, &N,
                     <double complex *>work.data, &lwork,
                     <double *>rwork.data, NULL, &info)

    if info > 0:
        raise LinAlgError("QZ iteration failed to converge in gges")

    assert info == 0, "Argument error in gges"

    # Real inputs possibly produce complex output
    cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai)

    return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
                       (A, B, vsl, vsr, alpha, beta))
def tgsen(np.ndarray[l_logical] select,
          np.ndarray[scalar, ndim=2] S,
          np.ndarray[scalar, ndim=2] T,
          np.ndarray[scalar, ndim=2] Q,
          np.ndarray[scalar, ndim=2] Z,
          calc_ev=True):
    """Reorder a generalized Schur decomposition (LAPACK xTGSEN).

    The eigenvalues flagged in ``select`` are moved to the leading block
    of ``(S, T)``. LAPACK updates ``S``, ``T`` and the given ``Q``/``Z``
    in place. No condition estimates are computed (IJOB = 0).

    Parameters
    ----------
    select : int32 vector with one flag per eigenvalue (length N).
    S, T : square Fortran-contiguous generalized Schur forms.
    Q, Z : left/right Schur vectors of the same shape, or None.
    calc_ev : whether to return the eigenvalues as ``(alpha, beta)``.

    Returns
    -------
    A tuple ``(S, T[, Q][, Z][, alpha, beta])``. ``Q``/``Z`` are included
    only if they were given. For real input, ``alpha`` is complex when any
    eigenvalue has a nonzero imaginary part.

    Raises
    ------
    ValueError : if the input shapes are inconsistent.
    LinAlgError : if the reordering fails (ill-conditioned problem).
    """
    cdef l_int ijob = 0
    cdef l_int N, M, lwork, liwork, info

    # Check parameters

    if ((S.shape[0] != S.shape[1] or T.shape[0] != T.shape[1] or
         S.shape[0] != T.shape[0]) or
        (Q is not None and (Q.shape[0] != Q.shape[1] or
                            S.shape[0] != Q.shape[0])) or
        (Z is not None and (Z.shape[0] != Z.shape[1] or
                            S.shape[0] != Z.shape[0]))):
        raise ValueError("Invalid Schur decomposition as input")

    assert_fortran_mat(S, T, Q, Z)

    N = S.shape[0]

    # LAPACK reads N entries of select; a missing or shorter vector would
    # be read out of bounds.
    if select is None or select.shape[0] != N:
        raise ValueError("select must have one entry per row of S")
    # LAPACK reads select.data as a contiguous array; this is a no-op for
    # an already contiguous select.
    select = np.ascontiguousarray(select, dtype=logical_dtype)

    # Allocate workspaces

    # Complex drivers return alpha as one complex array; real drivers
    # split it into real and imaginary parts.
    cdef np.ndarray[scalar] alphar, alphai
    if scalar in cmplx:
        alphar = np.empty(N, dtype=S.dtype)
        alphai = None
    else:
        alphar = np.empty(N, dtype=S.dtype)
        alphai = np.empty(N, dtype=S.dtype)

    cdef np.ndarray[scalar] beta
    beta = np.empty(N, dtype=S.dtype)

    cdef l_logical wantq
    cdef scalar *q_ptr
    if Q is not None:
        wantq = 1
        q_ptr = <scalar *>Q.data
    else:
        wantq = 0
        q_ptr = NULL

    cdef l_logical wantz
    cdef scalar *z_ptr
    if Z is not None:
        wantz = 1
        z_ptr = <scalar *>Z.data
    else:
        wantz = 0
        z_ptr = NULL

    # Workspace query (LWORK = LIWORK = -1).
    # Xtgsen expects &qwork as a <scalar *> (even though it's an integer)
    lwork = -1
    liwork = -1
    cdef scalar qwork
    cdef l_int qiwork

    if scalar is float:
        lapack.stgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <float *>S.data, &N,
                      <float *>T.data, &N,
                      <float *>alphar.data, <float *>alphai.data,
                      <float *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)
    elif scalar is double:
        lapack.dtgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <double *>S.data, &N,
                      <double *>T.data, &N,
                      <double *>alphar.data, <double *>alphai.data,
                      <double *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)
    elif scalar is float_complex:
        lapack.ctgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <float complex *>S.data, &N,
                      <float complex *>T.data, &N,
                      <float complex *>alphar.data, <float complex *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)
    elif scalar is double_complex:
        lapack.ztgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <double complex *>S.data, &N,
                      <double complex *>T.data, &N,
                      <double complex *>alphar.data, <double complex *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      &qwork, &lwork, &qiwork, &liwork, &info)

    assert info == 0, "Argument error in tgsen"

    lwork = lwork_from_qwork(qwork)
    cdef np.ndarray[scalar] work = np.empty(lwork, dtype=S.dtype)

    liwork = qiwork
    cdef np.ndarray[l_int] iwork = np.empty(liwork, dtype=int_dtype)

    # The actual calculation

    if scalar is float:
        lapack.stgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <float *>S.data, &N,
                      <float *>T.data, &N,
                      <float *>alphar.data, <float *>alphai.data,
                      <float *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      <float *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)
    elif scalar is double:
        lapack.dtgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <double *>S.data, &N,
                      <double *>T.data, &N,
                      <double *>alphar.data, <double *>alphai.data,
                      <double *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      <double *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)
    elif scalar is float_complex:
        lapack.ctgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <float complex *>S.data, &N,
                      <float complex *>T.data, &N,
                      <float complex *>alphar.data, <float complex *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      <float complex *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)
    elif scalar is double_complex:
        lapack.ztgsen(&ijob, &wantq, &wantz, <l_logical *>select.data,
                      &N, <double complex *>S.data, &N,
                      <double complex *>T.data, &N,
                      <double complex *>alphar.data, <double complex *>beta.data,
                      q_ptr, &N, z_ptr, &N, &M, NULL, NULL, NULL,
                      <double complex *>work.data, &lwork,
                      <l_int *>iwork.data, &liwork, &info)

    if info > 0:
        raise LinAlgError("Reordering failed; problem is very ill-conditioned")

    assert info == 0, "Argument error in tgsen"

    # Real inputs possibly produce complex output
    cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai)

    return filter_args((True, True, Q is not None, Z is not None,
                        calc_ev, calc_ev),
                       (S, T, Q, Z, alpha, beta))
def tgevc(np.ndarray[scalar, ndim=2] S,
          np.ndarray[scalar, ndim=2] T,
          np.ndarray[scalar, ndim=2] Q,
          np.ndarray[scalar, ndim=2] Z,
          np.ndarray[l_logical] select,
          left=False, right=True):
    """Compute eigenvectors of a generalized Schur form (S, T) with ?tgevc.

    Parameters
    ----------
    S, T : 2-d arrays
        Square, equally sized, Fortran-contiguous matrices of the generalized
        Schur form.
    Q, Z : 2-d arrays or None
        Square Fortran-contiguous matrices of the same size as S.  Computed
        left (right) eigenvectors are multiplied from the left by Q (Z).  If
        Q (Z) is None, no back-transformation is applied to the left (right)
        eigenvectors.  When `select` is None, Q and/or Z may be overwritten
        in place, because LAPACK is then asked to back-transform directly
        into them (HOWMNY='B').
    select : 1-d array of length N, or None
        Which eigenvectors to compute (nonzero = selected); None computes
        all of them.
    left, right : bool
        Whether to compute left and/or right eigenvectors.

    Returns
    -------
    ``(vl, vr)`` if both `left` and `right` are true, otherwise ``vl`` or
    ``vr`` alone; None if neither is requested.  Eigenvectors are stored as
    columns.  For real input whose S has nonzero subdiagonal entries (2x2
    blocks), the result is passed through ``txevc_postprocess`` with a
    complex dtype (presumably to assemble complex eigenvectors from the
    real/imaginary column pairs LAPACK returns).

    Raises
    ------
    ValueError
        If the input shapes are inconsistent, a matrix is not Fortran
        contiguous, or `select` does not have exactly one entry per row of S.
    LinAlgError is not raised here; LAPACK argument errors trip an assert.
    """
    cdef l_int N, info, M, MM

    # Check parameters

    if ((S.shape[0] != S.shape[1] or T.shape[0] != T.shape[1] or
         S.shape[0] != T.shape[0]) or
        (Q is not None and (Q.shape[0] != Q.shape[1] or
                            S.shape[0] != Q.shape[0])) or
        (Z is not None and (Z.shape[0] != Z.shape[1] or
                            S.shape[0] != Z.shape[0]))):
        raise ValueError("Invalid Schur decomposition as input")

    assert_fortran_mat(S, T, Q, Z)

    N = S.shape[0]

    # LAPACK reads exactly N entries of SELECT: a shorter array would be read
    # out of bounds, and a longer one would make MM (computed below from all
    # entries) disagree with the number of columns LAPACK actually fills.
    if select is not None and select.shape[0] != N:
        raise ValueError("select must have one entry per eigenvalue")

    # Allocate workspaces (sizes as required by ?tgevc)

    cdef np.ndarray[scalar] work
    if scalar in floating:
        work = np.empty(6 * N, dtype=S.dtype)
    else:
        work = np.empty(2 * N, dtype=S.dtype)

    cdef np.ndarray rwork = None
    if scalar is float_complex:
        rwork = np.empty(2 * N, dtype=np.float32)
    elif scalar is double_complex:
        rwork = np.empty(2 * N, dtype=np.float64)

    cdef char *side
    if left and right:
        side = "B"
    elif left:
        side = "L"
    elif right:
        side = "R"
    else:
        return

    # backtr: LAPACK back-transforms into Q/Z itself (HOWMNY='B').
    cdef l_logical backtr = False

    cdef char *howmny
    cdef np.ndarray[l_logical] select_cpy = None
    cdef l_logical *select_ptr
    if select is not None:
        howmny = "S"
        MM = select.nonzero()[0].size
        # Correct for possible additional storage if a single complex
        # eigenvalue is selected.
        # For that: Figure out the positions of the 2x2 blocks.
        cmplxindx = np.diagonal(S, -1).nonzero()[0]
        for i in cmplxindx:
            if bool(select[i]) != bool(select[i+1]):
                MM += 1

        # select is overwritten in tgevc
        select_cpy = np.array(select, dtype=logical_dtype,
                              order='F')
        select_ptr = <l_logical *>select_cpy.data
    else:
        MM = N
        select_ptr = NULL
        if ((left and right and Q is not None and Z is not None) or
            (left and not right and Q is not None) or
            (right and not left and Z is not None)):
            howmny = "B"
            backtr = True
        else:
            howmny = "A"

    cdef np.ndarray[scalar, ndim=2] vl_r
    cdef scalar *vl_r_ptr
    if left:
        if backtr:
            vl_r = Q
        else:
            vl_r = np.empty((N, MM), dtype=S.dtype, order='F')
        vl_r_ptr = <scalar *>vl_r.data
    else:
        vl_r_ptr = NULL

    cdef np.ndarray[scalar, ndim=2] vr_r
    cdef scalar *vr_r_ptr
    if right:
        if backtr:
            vr_r = Z
        else:
            vr_r = np.empty((N, MM), dtype=S.dtype, order='F')
        vr_r_ptr = <scalar *>vr_r.data
    else:
        vr_r_ptr = NULL

    if scalar is float:
        lapack.stgevc(side, howmny, select_ptr,
                      &N, <float *>S.data, &N,
                      <float *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <float *>work.data, &info)
    elif scalar is double:
        lapack.dtgevc(side, howmny, select_ptr,
                      &N, <double *>S.data, &N,
                      <double *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <double *>work.data, &info)
    elif scalar is float_complex:
        lapack.ctgevc(side, howmny, select_ptr,
                      &N, <float complex *>S.data, &N,
                      <float complex *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <float complex *>work.data, <float *>rwork.data, &info)
    elif scalar is double_complex:
        lapack.ztgevc(side, howmny, select_ptr,
                      &N, <double complex *>S.data, &N,
                      <double complex *>T.data, &N,
                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
                      <double complex *>work.data, <double *>rwork.data, &info)

    assert info == 0, "Argument error in tgevc"
    assert MM == M, "Unexpected number of eigenvectors returned in tgevc"

    if not backtr:
        # Back-transform only with the matrices that were actually supplied;
        # np.dot(None, ...) would fail otherwise.
        if left and Q is not None:
            vl_r = np.asfortranarray(np.dot(Q, vl_r))
        if right and Z is not None:
            vr_r = np.asfortranarray(np.dot(Z, vr_r))

    # If there are complex eigenvalues, we need to postprocess the eigenvectors
    cdef np.ndarray vl, vr
    if left:
        vl = vl_r
    if right:
        vr = vr_r
    if scalar in floating:
        if scalar is float:
            dtype = np.complex64
        else:
            dtype = np.complex128
        if np.diagonal(S, -1).nonzero()[0].size:
            if left:
                vl = txevc_postprocess(dtype, S, vl_r, select)
            if right:
                vr = txevc_postprocess(dtype, S, vr_r, select)

    if left and right:
        return (vl, vr)
    elif left:
        return vl
    else:
        return vr
def prepare_for_lapack(overwrite, *args):
    """Convert arrays to Fortran format.

    This function takes a number of array objects in `args` and converts them
    to a format that can be directly passed to a Fortran function (Fortran
    contiguous NumPy array). If the arrays have different data types, the
    converted arrays are cast to a common compatible data type (one of NumPy's
    `float32`, `float64`, `complex64`, `complex128` data types). Half-precision
    (`float16`) input is promoted losslessly to `float32`, since LAPACK has no
    half-precision routines.

    If `overwrite` is ``False``, a NumPy array that would already be in the
    correct format (Fortran contiguous, right data type) is nevertheless
    copied. (Hence, overwrite = True does not imply that acting on the
    converted array in the return values will overwrite the original array in
    all cases -- it does so only if the original array was already in the
    correct format; the conversions require copying. In fact, that's the same
    behavior as in SciPy, it's just not explicitly stated there.)

    If an argument is ``None``, it is just passed through and not used to
    determine the proper LAPACK type.

    Parameters
    ----------
    overwrite : bool
        Whether arrays that are already in the correct format may be returned
        without copying.
    *args : array_like or None
        The arrays to convert; each must be 1- or 2-dimensional and numeric.

    Returns
    -------
    The single converted array if exactly one argument was given, otherwise a
    tuple of converted arrays (``None`` entries are preserved in place).

    Raises
    ------
    ValueError
        If an argument is not numeric or is not 1- or 2-dimensional.
    AssertionError
        If the common data type is not one LAPACK supports (e.g. an
        extended-precision ``longdouble``).
    """

    # Make sure we have NumPy arrays.  Each entry is (array, may_overwrite);
    # an array newly created by asanyarray may always be overwritten.
    mats = []
    for arg in args:
        if arg is None:
            mats.append((None, True))
            continue
        arr = np.asanyarray(arg)
        if not np.issubdtype(arr.dtype, np.number):
            raise ValueError("Argument cannot be interpreted "
                             "as a numeric array")
        mats.append((arr, arr is not arg or overwrite))

    # First figure out common dtype.  np.common_type always returns a
    # floating point kind, but it returns float16 both for half-precision
    # input and when called without any arrays, so handle those explicitly.
    present = [arr for arr, _ in mats if arr is not None]
    dtype = None
    if present:
        dtype = np.common_type(*present)
        if dtype == np.float16:
            # Lossless promotion: there is no half-precision LAPACK.
            dtype = np.float32
        if dtype not in (np.float32, np.float64, np.complex64, np.complex128):
            raise AssertionError("Unexpected data type from common_type")

    ret = []
    for npmat, ovwrt in mats:
        # Now make sure that the array is contiguous, and copy if necessary.
        if npmat is not None:
            if npmat.ndim == 2:
                if not npmat.flags["F_CONTIGUOUS"]:
                    npmat = np.asfortranarray(npmat, dtype=dtype)
                elif npmat.dtype != dtype:
                    npmat = npmat.astype(dtype)
                elif not ovwrt:
                    # A single copy straight into Fortran order.
                    npmat = np.array(npmat, order='F')
            elif npmat.ndim == 1:
                # 1-d arrays that are C contiguous are Fortran contiguous too.
                if not npmat.flags["C_CONTIGUOUS"]:
                    npmat = np.ascontiguousarray(npmat, dtype=dtype)
                elif npmat.dtype != dtype:
                    npmat = npmat.astype(dtype)
                elif not ovwrt:
                    npmat = np.array(npmat)
            else:
                raise ValueError("Dimensionality of array is not 1 or 2")

        ret.append(npmat)

    if len(ret) == 1:
        return ret[0]
    return tuple(ret)