Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Copyright 2011-2013 Kwant authors. 

2# 

3# This file is part of Kwant. It is subject to the license terms in the file 

4# LICENSE.rst found in the top-level directory of this distribution and at 

5# https://kwant-project.org/license. A list of Kwant authors can be found in 

6# the file AUTHORS.rst at the top-level directory of this distribution and at 

7# https://kwant-project.org/authors. 

8 

# Names exported by ``from ... import *``: the module-level convenience
# functions bound to the default solver, plus the solver class itself.
__all__ = [
    'smatrix',
    'greens_function',
    'ldos',
    'wave_function',
    'Solver',
]

10 

11import numpy as np 

12import scipy.sparse as sp 

13from . import common 

14 

# Note: the previous code would have failed if UMFPACK was provided by scikit-umfpack.

16import scipy.sparse.linalg.dsolve.linsolve as linsolve 

17 

# Decide whether the UMFPACK code path can be used.
#
# The flag is read defensively:
# * the module bound to ``linsolve`` may not expose ``useUmfpack`` at all,
#   in which case we fall back to SuperLU instead of failing at import;
# * NOTE(review): some SciPy versions appear to store the flag on a
#   ``threading.local`` (as attribute ``.u``) rather than as a plain bool,
#   which would always be truthy -- confirm against the installed SciPy;
# * the UMFPACK branch below needs ``linsolve.umfpack`` (the scikit-umfpack
#   module), so its presence is required as well.  Otherwise the first
#   factorization would fail with an AttributeError.
_umfpack_flag = getattr(linsolve, 'useUmfpack', False)
_umfpack_flag = getattr(_umfpack_flag, 'u', _umfpack_flag)
uses_umfpack = bool(_umfpack_flag) and hasattr(linsolve, 'umfpack')
del _umfpack_flag

if uses_umfpack:
    def factorized(A, piv_tol=1.0, sym_piv_tol=1.0):
        """
        Return a function for solving a sparse linear system, with A
        pre-factorized.

        Parameters
        ----------
        A : csc_matrix
            matrix to be factorized. It is converted to CSC if it is not
            already. Its indices are sorted in place, so a CSC matrix passed
            in by the caller is modified (its values are unchanged).
        piv_tol : float, 0 <= piv_tol <= 1.0
        sym_piv_tol : float, 0 <= sym_piv_tol <= 1.0
            thresholds used by UMFPACK for pivoting. 0 means no pivoting, 1.0
            means full pivoting as in dense matrices (guaranteeing stability,
            but possibly reducing sparsity). Defaults of UMFPACK are 0.1 and
            0.001 respectively. Whether piv_tol or sym_piv_tol are used is
            decided internally by UMFPACK, depending on whether the matrix is
            "symmetric" enough.

        Returns
        -------
        solve : callable
            ``solve(b)`` returns the solution of ``A x = b``, reusing the
            LU factors computed here.

        Raises
        ------
        ValueError
            if, after upcasting to a floating point type, the data of A is
            neither float64 nor complex128.

        Examples
        --------
        solve = factorized(A)  # Makes LU decomposition.
        x1 = solve(rhs1)  # Uses the LU factors.
        x2 = solve(rhs2)  # Uses again the LU factors.
        """
        umfpack = linsolve.umfpack

        if not sp.isspmatrix_csc(A):
            A = sp.csc_matrix(A)

        A.sort_indices()
        A = A.asfptype()  # upcast to a floating point format

        # UMFPACK only supports double precision (real or complex).
        if A.dtype.char not in 'dD':
            raise ValueError("convert matrix data to double, please, using"
                             " .astype()")

        family = {'d': 'di', 'D': 'zi'}
        umf = umfpack.UmfpackContext(family[A.dtype.char])

        # adjust pivot thresholds
        umf.control[umfpack.UMFPACK_PIVOT_TOLERANCE] = piv_tol
        umf.control[umfpack.UMFPACK_SYM_PIVOT_TOLERANCE] = sym_piv_tol

        # Make LU decomposition.
        umf.numeric(A)

        def solve(b):
            return umf.solve(umfpack.UMFPACK_A, A, b, autoTranspose=True)

        return solve
else:
    # No UMFPACK found: SciPy's SuperLU-based factorization is used instead.
    # It is usually much slower. SuperLU is not bad per se, but somehow the
    # version in SciPy performs poorly. Since SciPy no longer includes
    # UMFPACK because of software rot, no warning is issued here.
    factorized = linsolve.factorized

77 

78 

class Solver(common.SparseSolver):
    """Sparse solver built on the direct sparse solvers shipped with SciPy.

    The system matrix is factorized once, in CSC format, and each column of
    the (CSC) right-hand side is then solved for separately.
    """
    # Sparse formats used for the system matrix and the right-hand side.
    lhsformat = 'csc'
    rhsformat = 'csc'
    # NOTE(review): presumably the number of right-hand-side columns the
    # base class hands over per call -- confirm against common.SparseSolver.
    nrhs = 1

    def _factorized(self, a):
        # Convert to CSC, then return the solve callable produced by the
        # module-level ``factorized``.
        csc_a = sp.csc_matrix(a)
        return factorized(csc_a)

    def _solve_linear_sys(self, factorized_a, b, kept_vars):
        # Solve for every column of the sparse right-hand side ``b`` in turn,
        # keep only the ``kept_vars`` entries of each solution, and return
        # the results stacked as the columns of a dense array.
        n_rows, n_cols = b.shape
        if not n_cols:
            # Nothing to solve: hand back the (empty) sparse slice unchanged.
            return b[kept_vars]

        dense_col = np.empty(n_rows, complex)  # buffer reused for each column
        solutions = []
        for col in range(n_cols):
            dense_col[:] = b[:, col].toarray().flatten()
            solutions.append(factorized_a(dense_col)[kept_vars])
        return np.asarray(solutions).T

100 

101 

# A shared solver instance; its bound methods are re-exported below as the
# module-level convenience functions.
default_solver = Solver()

smatrix, greens_function, ldos, wave_function = (
    default_solver.smatrix,
    default_solver.greens_function,
    default_solver.ldos,
    default_solver.wave_function,
)