dggsvd3.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. // Copyright ©2017 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gonum
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. "gonum.org/v1/gonum/lapack"
  9. )
  10. // Dggsvd3 computes the generalized singular value decomposition (GSVD)
  11. // of an m×n matrix A and p×n matrix B:
  12. // Uᵀ*A*Q = D1*[ 0 R ]
  13. //
  14. // Vᵀ*B*Q = D2*[ 0 R ]
  15. // where U, V and Q are orthogonal matrices.
  16. //
  17. // Dggsvd3 returns k and l, the dimensions of the sub-blocks. k+l
  18. // is the effective numerical rank of the (m+p)×n matrix [ Aᵀ Bᵀ ]ᵀ.
  19. // R is a (k+l)×(k+l) nonsingular upper triangular matrix, D1 and
  20. // D2 are m×(k+l) and p×(k+l) diagonal matrices and of the following
  21. // structures, respectively:
  22. //
  23. // If m-k-l >= 0,
  24. //
  25. // k l
  26. // D1 = k [ I 0 ]
  27. // l [ 0 C ]
  28. // m-k-l [ 0 0 ]
  29. //
  30. // k l
  31. // D2 = l [ 0 S ]
  32. // p-l [ 0 0 ]
  33. //
  34. // n-k-l k l
  35. // [ 0 R ] = k [ 0 R11 R12 ] k
  36. // l [ 0 0 R22 ] l
  37. //
  38. // where
  39. //
  40. // C = diag( alpha_k, ... , alpha_{k+l} ),
  41. // S = diag( beta_k, ... , beta_{k+l} ),
  42. // C^2 + S^2 = I.
  43. //
  44. // R is stored in
  45. // A[0:k+l, n-k-l:n]
  46. // on exit.
  47. //
  48. // If m-k-l < 0,
  49. //
  50. // k m-k k+l-m
  51. // D1 = k [ I 0 0 ]
  52. // m-k [ 0 C 0 ]
  53. //
  54. // k m-k k+l-m
  55. // D2 = m-k [ 0 S 0 ]
  56. // k+l-m [ 0 0 I ]
  57. // p-l [ 0 0 0 ]
  58. //
  59. // n-k-l k m-k k+l-m
  60. // [ 0 R ] = k [ 0 R11 R12 R13 ]
  61. // m-k [ 0 0 R22 R23 ]
  62. // k+l-m [ 0 0 0 R33 ]
  63. //
  64. // where
  65. // C = diag( alpha_k, ... , alpha_m ),
  66. // S = diag( beta_k, ... , beta_m ),
  67. // C^2 + S^2 = I.
  68. //
  69. // R = [ R11 R12 R13 ] is stored in A[1:m, n-k-l+1:n]
  70. // [ 0 R22 R23 ]
  71. // and R33 is stored in
  72. // B[m-k:l, n+m-k-l:n] on exit.
  73. //
  74. // Dggsvd3 computes C, S, R, and optionally the orthogonal transformation
  75. // matrices U, V and Q.
  76. //
  77. // jobU, jobV and jobQ are options for computing the orthogonal matrices. The behavior
  78. // is as follows
  79. // jobU == lapack.GSVDU Compute orthogonal matrix U
  80. // jobU == lapack.GSVDNone Do not compute orthogonal matrix.
  81. // The behavior is the same for jobV and jobQ with the exception that instead of
  82. // lapack.GSVDU these accept lapack.GSVDV and lapack.GSVDQ respectively.
  83. // The matrices U, V and Q must be m×m, p×p and n×n respectively unless the
  84. // relevant job parameter is lapack.GSVDNone.
  85. //
  86. // alpha and beta must have length n or Dggsvd3 will panic. On exit, alpha and
  87. // beta contain the generalized singular value pairs of A and B
  88. // alpha[0:k] = 1,
  89. // beta[0:k] = 0,
  90. // if m-k-l >= 0,
  91. // alpha[k:k+l] = diag(C),
  92. // beta[k:k+l] = diag(S),
  93. // if m-k-l < 0,
  94. // alpha[k:m]= C, alpha[m:k+l]= 0
  95. // beta[k:m] = S, beta[m:k+l] = 1.
  96. // if k+l < n,
  97. // alpha[k+l:n] = 0 and
  98. // beta[k+l:n] = 0.
  99. //
  100. // On exit, iwork contains the permutation required to sort alpha descending.
  101. //
  102. // iwork must have length n, work must have length at least max(1, lwork), and
  103. // lwork must be -1 or greater than n, otherwise Dggsvd3 will panic. If
  104. // lwork is -1, work[0] holds the optimal lwork on return, but Dggsvd3 does
  105. // not perform the GSVD.
  106. func (impl Implementation) Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool) {
  107. wantu := jobU == lapack.GSVDU
  108. wantv := jobV == lapack.GSVDV
  109. wantq := jobQ == lapack.GSVDQ
  110. switch {
  111. case !wantu && jobU != lapack.GSVDNone:
  112. panic(badGSVDJob + "U")
  113. case !wantv && jobV != lapack.GSVDNone:
  114. panic(badGSVDJob + "V")
  115. case !wantq && jobQ != lapack.GSVDNone:
  116. panic(badGSVDJob + "Q")
  117. case m < 0:
  118. panic(mLT0)
  119. case n < 0:
  120. panic(nLT0)
  121. case p < 0:
  122. panic(pLT0)
  123. case lda < max(1, n):
  124. panic(badLdA)
  125. case ldb < max(1, n):
  126. panic(badLdB)
  127. case ldu < 1, wantu && ldu < m:
  128. panic(badLdU)
  129. case ldv < 1, wantv && ldv < p:
  130. panic(badLdV)
  131. case ldq < 1, wantq && ldq < n:
  132. panic(badLdQ)
  133. case len(iwork) < n:
  134. panic(shortWork)
  135. case lwork < 1 && lwork != -1:
  136. panic(badLWork)
  137. case len(work) < max(1, lwork):
  138. panic(shortWork)
  139. }
  140. // Determine optimal work length.
  141. impl.Dggsvp3(jobU, jobV, jobQ,
  142. m, p, n,
  143. a, lda,
  144. b, ldb,
  145. 0, 0,
  146. u, ldu,
  147. v, ldv,
  148. q, ldq,
  149. iwork,
  150. work, work, -1)
  151. lwkopt := n + int(work[0])
  152. lwkopt = max(lwkopt, 2*n)
  153. lwkopt = max(lwkopt, 1)
  154. work[0] = float64(lwkopt)
  155. if lwork == -1 {
  156. return 0, 0, true
  157. }
  158. switch {
  159. case len(a) < (m-1)*lda+n:
  160. panic(shortA)
  161. case len(b) < (p-1)*ldb+n:
  162. panic(shortB)
  163. case wantu && len(u) < (m-1)*ldu+m:
  164. panic(shortU)
  165. case wantv && len(v) < (p-1)*ldv+p:
  166. panic(shortV)
  167. case wantq && len(q) < (n-1)*ldq+n:
  168. panic(shortQ)
  169. case len(alpha) != n:
  170. panic(badLenAlpha)
  171. case len(beta) != n:
  172. panic(badLenBeta)
  173. }
  174. // Compute the Frobenius norm of matrices A and B.
  175. anorm := impl.Dlange(lapack.Frobenius, m, n, a, lda, nil)
  176. bnorm := impl.Dlange(lapack.Frobenius, p, n, b, ldb, nil)
  177. // Get machine precision and set up threshold for determining
  178. // the effective numerical rank of the matrices A and B.
  179. tola := float64(max(m, n)) * math.Max(anorm, dlamchS) * dlamchP
  180. tolb := float64(max(p, n)) * math.Max(bnorm, dlamchS) * dlamchP
  181. // Preprocessing.
  182. k, l = impl.Dggsvp3(jobU, jobV, jobQ,
  183. m, p, n,
  184. a, lda,
  185. b, ldb,
  186. tola, tolb,
  187. u, ldu,
  188. v, ldv,
  189. q, ldq,
  190. iwork,
  191. work[:n], work[n:], lwork-n)
  192. // Compute the GSVD of two upper "triangular" matrices.
  193. _, ok = impl.Dtgsja(jobU, jobV, jobQ,
  194. m, p, n,
  195. k, l,
  196. a, lda,
  197. b, ldb,
  198. tola, tolb,
  199. alpha, beta,
  200. u, ldu,
  201. v, ldv,
  202. q, ldq,
  203. work)
  204. // Sort the singular values and store the pivot indices in iwork
  205. // Copy alpha to work, then sort alpha in work.
  206. bi := blas64.Implementation()
  207. bi.Dcopy(n, alpha, 1, work[:n], 1)
  208. ibnd := min(l, m-k)
  209. for i := 0; i < ibnd; i++ {
  210. // Scan for largest alpha_{k+i}.
  211. isub := i
  212. smax := work[k+i]
  213. for j := i + 1; j < ibnd; j++ {
  214. if v := work[k+j]; v > smax {
  215. isub = j
  216. smax = v
  217. }
  218. }
  219. if isub != i {
  220. work[k+isub] = work[k+i]
  221. work[k+i] = smax
  222. iwork[k+i] = k + isub
  223. } else {
  224. iwork[k+i] = k + i
  225. }
  226. }
  227. work[0] = float64(lwkopt)
  228. return k, l, ok
  229. }