svd.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas/blas64"
  7. "gonum.org/v1/gonum/lapack"
  8. "gonum.org/v1/gonum/lapack/lapack64"
  9. )
  10. // SVD is a type for creating and using the Singular Value Decomposition (SVD)
  11. // of a matrix.
  12. type SVD struct {
  13. kind SVDKind
  14. s []float64
  15. u blas64.General
  16. vt blas64.General
  17. }
  18. // SVDKind specifies the treatment of singular vectors during an SVD
  19. // factorization.
  20. type SVDKind int
  21. const (
  22. // SVDNone specifies that no singular vectors should be computed during
  23. // the decomposition.
  24. SVDNone SVDKind = 0
  25. // SVDThinU specifies the thin decomposition for U should be computed.
  26. SVDThinU SVDKind = 1 << (iota - 1)
  27. // SVDFullU specifies the full decomposition for U should be computed.
  28. SVDFullU
  29. // SVDThinV specifies the thin decomposition for V should be computed.
  30. SVDThinV
  31. // SVDFullV specifies the full decomposition for V should be computed.
  32. SVDFullV
  33. // SVDThin is a convenience value for computing both thin vectors.
  34. SVDThin SVDKind = SVDThinU | SVDThinV
  35. // SVDThin is a convenience value for computing both full vectors.
  36. SVDFull SVDKind = SVDFullU | SVDFullV
  37. )
  38. // succFact returns whether the receiver contains a successful factorization.
  39. func (svd *SVD) succFact() bool {
  40. return len(svd.s) != 0
  41. }
  42. // Factorize computes the singular value decomposition (SVD) of the input matrix A.
  43. // The singular values of A are computed in all cases, while the singular
  44. // vectors are optionally computed depending on the input kind.
  45. //
  46. // The full singular value decomposition (kind == SVDFull) is a factorization
  47. // of an m×n matrix A of the form
  48. // A = U * Σ * V^T
  49. // where Σ is an m×n diagonal matrix, U is an m×m orthogonal matrix, and V is an
  50. // n×n orthogonal matrix. The diagonal elements of Σ are the singular values of A.
  51. // The first min(m,n) columns of U and V are, respectively, the left and right
  52. // singular vectors of A.
  53. //
  54. // Significant storage space can be saved by using the thin representation of
  55. // the SVD (kind == SVDThin) instead of the full SVD, especially if
  56. // m >> n or m << n. The thin SVD finds
  57. // A = U~ * Σ * V~^T
  58. // where U~ is of size m×min(m,n), Σ is a diagonal matrix of size min(m,n)×min(m,n)
  59. // and V~ is of size n×min(m,n).
  60. //
  61. // Factorize returns whether the decomposition succeeded. If the decomposition
  62. // failed, routines that require a successful factorization will panic.
  63. func (svd *SVD) Factorize(a Matrix, kind SVDKind) (ok bool) {
  64. // kill previous factorization
  65. svd.s = svd.s[:0]
  66. svd.kind = kind
  67. m, n := a.Dims()
  68. var jobU, jobVT lapack.SVDJob
  69. // TODO(btracey): This code should be modified to have the smaller
  70. // matrix written in-place into aCopy when the lapack/native/dgesvd
  71. // implementation is complete.
  72. switch {
  73. case kind&SVDFullU != 0:
  74. jobU = lapack.SVDAll
  75. svd.u = blas64.General{
  76. Rows: m,
  77. Cols: m,
  78. Stride: m,
  79. Data: use(svd.u.Data, m*m),
  80. }
  81. case kind&SVDThinU != 0:
  82. jobU = lapack.SVDStore
  83. svd.u = blas64.General{
  84. Rows: m,
  85. Cols: min(m, n),
  86. Stride: min(m, n),
  87. Data: use(svd.u.Data, m*min(m, n)),
  88. }
  89. default:
  90. jobU = lapack.SVDNone
  91. }
  92. switch {
  93. case kind&SVDFullV != 0:
  94. svd.vt = blas64.General{
  95. Rows: n,
  96. Cols: n,
  97. Stride: n,
  98. Data: use(svd.vt.Data, n*n),
  99. }
  100. jobVT = lapack.SVDAll
  101. case kind&SVDThinV != 0:
  102. svd.vt = blas64.General{
  103. Rows: min(m, n),
  104. Cols: n,
  105. Stride: n,
  106. Data: use(svd.vt.Data, min(m, n)*n),
  107. }
  108. jobVT = lapack.SVDStore
  109. default:
  110. jobVT = lapack.SVDNone
  111. }
  112. // A is destroyed on call, so copy the matrix.
  113. aCopy := DenseCopyOf(a)
  114. svd.kind = kind
  115. svd.s = use(svd.s, min(m, n))
  116. work := []float64{0}
  117. lapack64.Gesvd(jobU, jobVT, aCopy.mat, svd.u, svd.vt, svd.s, work, -1)
  118. work = getFloats(int(work[0]), false)
  119. ok = lapack64.Gesvd(jobU, jobVT, aCopy.mat, svd.u, svd.vt, svd.s, work, len(work))
  120. putFloats(work)
  121. if !ok {
  122. svd.kind = 0
  123. }
  124. return ok
  125. }
  126. // Kind returns the SVDKind of the decomposition. If no decomposition has been
  127. // computed, Kind returns -1.
  128. func (svd *SVD) Kind() SVDKind {
  129. if !svd.succFact() {
  130. return -1
  131. }
  132. return svd.kind
  133. }
  134. // Cond returns the 2-norm condition number for the factorized matrix. Cond will
  135. // panic if the receiver does not contain a successful factorization.
  136. func (svd *SVD) Cond() float64 {
  137. if !svd.succFact() {
  138. panic(badFact)
  139. }
  140. return svd.s[0] / svd.s[len(svd.s)-1]
  141. }
  142. // Values returns the singular values of the factorized matrix in descending order.
  143. //
  144. // If the input slice is non-nil, the values will be stored in-place into
  145. // the slice. In this case, the slice must have length min(m,n), and Values will
  146. // panic with ErrSliceLengthMismatch otherwise. If the input slice is nil, a new
  147. // slice of the appropriate length will be allocated and returned.
  148. //
  149. // Values will panic if the receiver does not contain a successful factorization.
  150. func (svd *SVD) Values(s []float64) []float64 {
  151. if !svd.succFact() {
  152. panic(badFact)
  153. }
  154. if s == nil {
  155. s = make([]float64, len(svd.s))
  156. }
  157. if len(s) != len(svd.s) {
  158. panic(ErrSliceLengthMismatch)
  159. }
  160. copy(s, svd.s)
  161. return s
  162. }
  163. // UTo extracts the matrix U from the singular value decomposition. The first
  164. // min(m,n) columns are the left singular vectors and correspond to the singular
  165. // values as returned from SVD.Values.
  166. //
  167. // If dst is not nil, U is stored in-place into dst, and dst must have size
  168. // m×m if the full U was computed, size m×min(m,n) if the thin U was computed,
  169. // and UTo panics otherwise. If dst is nil, a new matrix of the appropriate size
  170. // is allocated and returned.
  171. func (svd *SVD) UTo(dst *Dense) *Dense {
  172. if !svd.succFact() {
  173. panic(badFact)
  174. }
  175. kind := svd.kind
  176. if kind&SVDThinU == 0 && kind&SVDFullU == 0 {
  177. panic("svd: u not computed during factorization")
  178. }
  179. r := svd.u.Rows
  180. c := svd.u.Cols
  181. if dst == nil {
  182. dst = NewDense(r, c, nil)
  183. } else {
  184. dst.reuseAs(r, c)
  185. }
  186. tmp := &Dense{
  187. mat: svd.u,
  188. capRows: r,
  189. capCols: c,
  190. }
  191. dst.Copy(tmp)
  192. return dst
  193. }
  194. // VTo extracts the matrix V from the singular value decomposition. The first
  195. // min(m,n) columns are the right singular vectors and correspond to the singular
  196. // values as returned from SVD.Values.
  197. //
  198. // If dst is not nil, V is stored in-place into dst, and dst must have size
  199. // n×n if the full V was computed, size n×min(m,n) if the thin V was computed,
  200. // and VTo panics otherwise. If dst is nil, a new matrix of the appropriate size
  201. // is allocated and returned.
  202. func (svd *SVD) VTo(dst *Dense) *Dense {
  203. if !svd.succFact() {
  204. panic(badFact)
  205. }
  206. kind := svd.kind
  207. if kind&SVDThinU == 0 && kind&SVDFullV == 0 {
  208. panic("svd: v not computed during factorization")
  209. }
  210. r := svd.vt.Rows
  211. c := svd.vt.Cols
  212. if dst == nil {
  213. dst = NewDense(c, r, nil)
  214. } else {
  215. dst.reuseAs(c, r)
  216. }
  217. tmp := &Dense{
  218. mat: svd.vt,
  219. capRows: r,
  220. capCols: c,
  221. }
  222. dst.Copy(tmp.T())
  223. return dst
  224. }