Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Matrix square root for general matrices and for upper triangular matrices. 

3 

4This module exists to avoid cyclic imports. 

5 

6""" 

7__all__ = ['sqrtm'] 

8 

9import numpy as np 

10 

11from scipy._lib._util import _asarray_validated 

12 

13 

14# Local imports 

15from .misc import norm 

16from .lapack import ztrsyl, dtrsyl 

17from .decomp_schur import schur, rsf2csf 

18 

19 

class SqrtmError(np.linalg.LinAlgError):
    """Raised by `_sqrtm_triu` when a matrix square root cannot be found."""

22 

23 

def _sqrtm_triu(T, blocksize=64):
    """
    Matrix square root of an upper triangular matrix.

    This is a helper function for `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root to evaluate
    blocksize : int, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`.  The result is real when `T` is
        real with a non-negative diagonal, and complex otherwise.

    Raises
    ------
    SqrtmError
        If, inside a diagonal block, ``R[i, i] + R[j, j]`` is zero while the
        corresponding right-hand side is non-zero.
    Exception
        If the computed block sizes do not add up to N
        ('internal inconsistency').

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    # Honour the documented ``array_like`` contract: the code below relies
    # on ``.shape`` and 2-D tuple indexing, which plain sequences lack.
    T = np.asarray(T)

    T_diag = np.diag(T)
    # Stay in real arithmetic only if every diagonal entry has a real square
    # root.  An empty diagonal trivially qualifies; the explicit size test
    # avoids calling ``np.min`` on a zero-size array, which raises.
    keep_it_real = np.isrealobj(T) and (T_diag.size == 0 or
                                        np.min(T_diag) >= 0)
    if not keep_it_real:
        T_diag = T_diag.astype(complex)
    R = np.diag(np.sqrt(T_diag))

    # Compute the number of blocks to use; use at least one block.
    _, n = T.shape
    nblocks = max(n // blocksize, 1)

    # Compute the smaller of the two sizes of blocks that
    # we will actually use, and compute the number of large blocks.
    bsmall, nlarge = divmod(n, nblocks)
    blarge = bsmall + 1
    nsmall = nblocks - nlarge
    if nsmall * bsmall + nlarge * blarge != n:
        raise Exception('internal inconsistency')

    # Define the index range covered by each block: all the small blocks
    # first, followed by the large ones.
    start_stop_pairs = []
    start = 0
    for count, size in ((nsmall, bsmall), (nlarge, blarge)):
        for _ in range(count):
            start_stop_pairs.append((start, start + size))
            start += size

    # Within-block interactions: fill each diagonal block column by column,
    # walking upwards from the diagonal so that every entry of R read on
    # the right-hand side has already been computed.
    for start, stop in start_stop_pairs:
        for j in range(start, stop):
            for i in range(j-1, start-1, -1):
                s = 0
                if j - i > 1:
                    s = R[i, i+1:j].dot(R[i+1:j, j])
                denom = R[i, i] + R[j, j]
                num = T[i, j] - s
                if denom != 0:
                    R[i, j] = num / denom
                elif num == 0:
                    # 0 / 0: any value satisfies the equation; choose 0.
                    R[i, j] = 0
                else:
                    raise SqrtmError('failed to find the matrix square root')

    # Between-block interactions: each off-diagonal block is obtained from
    # a Sylvester equation involving the two diagonal blocks it sits between.
    for j in range(nblocks):
        jstart, jstop = start_stop_pairs[j]
        for i in range(j-1, -1, -1):
            istart, istop = start_stop_pairs[i]
            S = T[istart:istop, jstart:jstop]
            if j - i > 1:
                S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
                                                            jstart:jstop])

            # Invoke LAPACK.
            # For more details, see the solve_sylvester implementation
            # and the fortran dtrsyl and ztrsyl docs.
            Rii = R[istart:istop, istart:istop]
            Rjj = R[jstart:jstop, jstart:jstop]
            if keep_it_real:
                x, scale, info = dtrsyl(Rii, Rjj, S)
            else:
                x, scale, info = ztrsyl(Rii, Rjj, S)
            # `info` is not inspected here.
            # NOTE(review): the LAPACK docs describe ?trsyl as solving
            # A*X + X*B = scale*C; confirm that multiplying by `scale` is
            # the intended treatment when scale != 1.
            R[istart:istop, jstart:jstop] = x * scale

    # Return the matrix square root.
    return R

115 

116 

def sqrtm(A, disp=True, blocksize=64):
    """
    Compute a square root of a matrix via its Schur decomposition.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose square root is wanted.  It is validated to contain
        only finite values and is converted to an inexact dtype.
    disp : bool, optional
        When true, only the square root is returned and a message is
        printed if no square root could be found.  When false, an error
        estimate is returned alongside the result and nothing is printed.
        (Default: True)
    blocksize : integer, optional
        Block size handed to the triangular square-root routine; must be
        at least 1. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `A`.  If the triangular routine
        fails, an array shaped like `A` and filled with NaN.

    errest : float
        (if disp == False)

        ``norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')`` where ``X`` is the
        returned square root, or ``inf`` if evaluating the norms raises
        ``ValueError``.

    Raises
    ------
    ValueError
        If `A` is not two-dimensional or `blocksize` is smaller than 1.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root,
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    Examples
    --------
    >>> from scipy.linalg import sqrtm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> r = sqrtm(a)
    >>> r
    array([[ 0.75592895,  1.13389342],
           [ 0.37796447,  1.88982237]])
    >>> r.dot(r)
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    A = _asarray_validated(A, check_finite=True, as_inexact=True)
    if A.ndim != 2:
        raise ValueError("Non-matrix input to matrix function.")
    if blocksize < 1:
        raise ValueError("The blocksize should be at least 1.")

    # Reduce A to an upper triangular factor T with A = Z T Z^H.
    if np.isrealobj(A):
        T, Z = schur(A)
        # A real Schur form that is not strictly upper triangular is
        # converted to the complex Schur form.
        if not np.array_equal(T, np.triu(T)):
            T, Z = rsf2csf(T, Z)
    else:
        T, Z = schur(A, output='complex')

    try:
        root_T = _sqrtm_triu(T, blocksize=blocksize)
        X = Z.dot(root_T).dot(Z.conj().T)
    except SqrtmError:
        failed = True
        X = np.full_like(A, np.nan)
    else:
        failed = False

    if not disp:
        try:
            errest = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
        except ValueError:
            # NaNs in matrix
            errest = np.inf
        return X, errest

    if failed:
        print("Failed to find a square root.")
    return X