Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright 2011-2013 Kwant authors.
2#
3# This file is part of Kwant. It is subject to the license terms in the file
4# LICENSE.rst found in the top-level directory of this distribution and at
5# https://kwant-project.org/license. A list of Kwant authors can be found in
6# the file AUTHORS.rst at the top-level directory of this distribution and at
7# https://kwant-project.org/authors.
# Public API of this module: the convenience functions bound to the default
# solver instance, plus the solver class itself.
__all__ = [
    'smatrix',
    'greens_function',
    'ldos',
    'wave_function',
    'Solver',
]
11import numpy as np
12import scipy.sparse as sp
13from . import common
15# Note: previous code would have failed if UMFPACK was provided by scikit
16import scipy.sparse.linalg.dsolve.linsolve as linsolve
# The UMFPACK implementation below dereferences ``linsolve.umfpack``, so only
# select it when SciPy both wants to use UMFPACK *and* actually managed to
# import it; trusting the flag alone can leave ``factorized`` pointing at code
# that fails with AttributeError on first use.
uses_umfpack = bool(linsolve.useUmfpack) and hasattr(linsolve, 'umfpack')

if uses_umfpack:
    def factorized(A, piv_tol=1.0, sym_piv_tol=1.0):
        """
        Return a function for solving a sparse linear system, with A
        pre-factorized.

        Parameters
        ----------
        A : sparse matrix or array_like
            matrix to be factorized. It is converted to CSC format and
            upcast to double precision (real or complex) if necessary.
        piv_tol : float, 0 <= piv_tol <= 1.0
        sym_piv_tol : float, 0 <= sym_piv_tol <= 1.0
            thresholds used by UMFPACK for pivoting. 0 means no pivoting, 1.0
            means full pivoting as in dense matrices (guaranteeing stability,
            but possibly reducing sparsity). UMFPACK's own defaults are 0.1
            and 0.001 respectively; this function defaults to 1.0 for both.
            Whether piv_tol or sym_piv_tol is used is decided internally by
            UMFPACK, depending on whether the matrix is "symmetric" enough.

        Returns
        -------
        solve : callable
            ``solve(b)`` returns the solution of ``A x = b`` computed from
            the precomputed LU factors.

        Raises
        ------
        ValueError
            If the matrix data cannot be represented in double precision
            (e.g. extended-precision or object dtypes).

        Examples
        --------
        solve = factorized(A) # Makes LU decomposition.
        x1 = solve(rhs1) # Uses the LU factors.
        x2 = solve(rhs2) # Uses again the LU factors.
        """
        umfpack = linsolve.umfpack

        if not sp.isspmatrix_csc(A):
            A = sp.csc_matrix(A)

        A.sort_indices()
        # UMFPACK only handles double precision. ``np.promote_types`` maps
        # bool/int/float32 to float64 and complex64 to complex128, so inputs
        # that the deprecated ``asfptype`` left in single precision (and that
        # were then rejected below) are now upcast instead. ``copy=False``
        # avoids duplicating data that is already in the right precision.
        A = A.astype(np.promote_types(A.dtype, np.float64), copy=False)

        if A.dtype.char not in 'dD':
            raise ValueError("convert matrix data to double, please, using"
                             " .astype()")

        # UMFPACK family: 'di' for real double data, 'zi' for complex double.
        family = {'d': 'di', 'D': 'zi'}
        umf = umfpack.UmfpackContext(family[A.dtype.char])

        # Adjust pivot thresholds.
        umf.control[umfpack.UMFPACK_PIVOT_TOLERANCE] = piv_tol
        umf.control[umfpack.UMFPACK_SYM_PIVOT_TOLERANCE] = sym_piv_tol

        # Make LU decomposition once; every call to ``solve`` reuses it.
        umf.numeric(A)

        def solve(b):
            return umf.solve(umfpack.UMFPACK_A, A, b, autoTranspose=True)

        return solve
else:
    # No UMFPACK found: fall back to SciPy's SuperLU-based ``factorized``.
    # SuperLU is not bad per se, but the SciPy build of it is usually
    # abysmally slow. Since SciPy no longer includes UMFPACK (it was dropped
    # due to software rot), no warning is issued here.
    factorized = linsolve.factorized
class Solver(common.SparseSolver):
    """Sparse solver built on the direct sparse solvers available in SciPy.

    LU factorizations come from the module-level ``factorized`` function
    (UMFPACK when it is usable, SciPy's own ``factorized`` otherwise).
    """
    # Sparse storage formats requested for the system matrix (lhs) and the
    # right-hand side (rhs).
    lhsformat = 'csc'
    rhsformat = 'csc'
    # NOTE(review): presumably the number of right-hand-side columns handled
    # per solve by the base class machinery; this class solves one column at
    # a time in ``_solve_linear_sys``.
    nrhs = 1

    def _factorized(self, a):
        """Convert ``a`` to CSC format and return its pre-factorized solver."""
        return factorized(sp.csc_matrix(a))

    def _solve_linear_sys(self, factorized_a, b, kept_vars):
        """Solve the factorized system for every column of ``b``.

        Each column of the (sparse) right-hand side ``b`` is densified into a
        reusable complex buffer, solved with ``factorized_a``, and only the
        entries selected by ``kept_vars`` are kept. The per-column results
        are returned as the columns of a dense array. When ``b`` has no
        columns, ``b[kept_vars]`` is returned as-is.
        """
        n_rows, n_cols = b.shape
        if not n_cols:
            return b[kept_vars]

        # One complex buffer, overwritten for each column of ``b``.
        rhs_column = np.empty(n_rows, complex)
        solutions = []
        for col in range(n_cols):
            rhs_column[:] = b[:, col].toarray().ravel()
            solutions.append(factorized_a(rhs_column)[kept_vars])

        # Rows of the stacked list are per-column solutions; transpose so
        # that each solution becomes a column.
        return np.array(solutions).T
# Module-wide solver instance; its bound methods are exported as the
# module-level convenience functions listed in ``__all__``.
default_solver = Solver()

smatrix, greens_function, ldos, wave_function = (
    default_solver.smatrix,
    default_solver.greens_function,
    default_solver.ldos,
    default_solver.wave_function,
)