qr.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/lapack"
  10. "gonum.org/v1/gonum/lapack/lapack64"
  11. )
  12. const badQR = "mat: invalid QR factorization"
  13. // QR is a type for creating and using the QR factorization of a matrix.
  14. type QR struct {
  15. qr *Dense
  16. tau []float64
  17. cond float64
  18. }
  19. func (qr *QR) updateCond(norm lapack.MatrixNorm) {
  20. // Since A = Q*R, and Q is orthogonal, we get for the condition number κ
  21. // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Q^T|
  22. // = |R| |R^-1| = κ(R),
  23. // where we used that fact that Q^-1 = Q^T. However, this assumes that
  24. // the matrix norm is invariant under orthogonal transformations which
  25. // is not the case for CondNorm. Hopefully the error is negligible: κ
  26. // is only a qualitative measure anyway.
  27. n := qr.qr.mat.Cols
  28. work := getFloats(3*n, false)
  29. iwork := getInts(n, false)
  30. r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper)
  31. v := lapack64.Trcon(norm, r.mat, work, iwork)
  32. putFloats(work)
  33. putInts(iwork)
  34. qr.cond = 1 / v
  35. }
  36. // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR
  37. // factorization always exists even if A is singular.
  38. //
  39. // The QR decomposition is a factorization of the matrix A such that A = Q * R.
  40. // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix.
  41. // Q and R can be extracted using the QTo and RTo methods.
  42. func (qr *QR) Factorize(a Matrix) {
  43. qr.factorize(a, CondNorm)
  44. }
  45. func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
  46. m, n := a.Dims()
  47. if m < n {
  48. panic(ErrShape)
  49. }
  50. k := min(m, n)
  51. if qr.qr == nil {
  52. qr.qr = &Dense{}
  53. }
  54. qr.qr.Clone(a)
  55. work := []float64{0}
  56. qr.tau = make([]float64, k)
  57. lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
  58. work = getFloats(int(work[0]), false)
  59. lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
  60. putFloats(work)
  61. qr.updateCond(norm)
  62. }
  63. // isValid returns whether the receiver contains a factorization.
  64. func (qr *QR) isValid() bool {
  65. return qr.qr != nil && !qr.qr.IsZero()
  66. }
  67. // Cond returns the condition number for the factorized matrix.
  68. // Cond will panic if the receiver does not contain a factorization.
  69. func (qr *QR) Cond() float64 {
  70. if !qr.isValid() {
  71. panic(badQR)
  72. }
  73. return qr.cond
  74. }
// TODO(btracey): Add in the "Reduced" forms for extracting the m×n matrix Q
// with orthonormal columns and the n×n upper triangular matrix R.
  77. // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition.
  78. // If dst is nil, a new matrix is allocated. The resulting dst matrix is returned.
  79. // RTo will panic if the receiver does not contain a factorization.
  80. func (qr *QR) RTo(dst *Dense) *Dense {
  81. if !qr.isValid() {
  82. panic(badQR)
  83. }
  84. r, c := qr.qr.Dims()
  85. if dst == nil {
  86. dst = NewDense(r, c, nil)
  87. } else {
  88. dst.reuseAs(r, c)
  89. }
  90. // Disguise the QR as an upper triangular
  91. t := &TriDense{
  92. mat: blas64.Triangular{
  93. N: c,
  94. Stride: qr.qr.mat.Stride,
  95. Data: qr.qr.mat.Data,
  96. Uplo: blas.Upper,
  97. Diag: blas.NonUnit,
  98. },
  99. cap: qr.qr.capCols,
  100. }
  101. dst.Copy(t)
  102. // Zero below the triangular.
  103. for i := r; i < c; i++ {
  104. zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c])
  105. }
  106. return dst
  107. }
  108. // QTo extracts the m×m orthonormal matrix Q from a QR decomposition.
  109. // If dst is nil, a new matrix is allocated. The resulting Q matrix is returned.
  110. // QTo will panic if the receiver does not contain a factorization.
  111. func (qr *QR) QTo(dst *Dense) *Dense {
  112. if !qr.isValid() {
  113. panic(badQR)
  114. }
  115. r, _ := qr.qr.Dims()
  116. if dst == nil {
  117. dst = NewDense(r, r, nil)
  118. } else {
  119. dst.reuseAsZeroed(r, r)
  120. }
  121. // Set Q = I.
  122. for i := 0; i < r*r; i += r + 1 {
  123. dst.mat.Data[i] = 1
  124. }
  125. // Construct Q from the elementary reflectors.
  126. work := []float64{0}
  127. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
  128. work = getFloats(int(work[0]), false)
  129. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
  130. putFloats(work)
  131. return dst
  132. }
  133. // SolveTo finds a minimum-norm solution to a system of linear equations defined
  134. // by the matrices A and b, where A is an m×n matrix represented in its QR factorized
  135. // form. If A is singular or near-singular a Condition error is returned.
  136. // See the documentation for Condition for more information.
  137. //
  138. // The minimization problem solved depends on the input parameters.
  139. // If trans == false, find X such that ||A*X - B||_2 is minimized.
  140. // If trans == true, find the minimum norm solution of A^T * X = B.
  141. // The solution matrix, X, is stored in place into dst.
  142. // SolveTo will panic if the receiver does not contain a factorization.
  143. func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
  144. if !qr.isValid() {
  145. panic(badQR)
  146. }
  147. r, c := qr.qr.Dims()
  148. br, bc := b.Dims()
  149. // The QR solve algorithm stores the result in-place into the right hand side.
  150. // The storage for the answer must be large enough to hold both b and x.
  151. // However, this method's receiver must be the size of x. Copy b, and then
  152. // copy the result into m at the end.
  153. if trans {
  154. if c != br {
  155. panic(ErrShape)
  156. }
  157. dst.reuseAs(r, bc)
  158. } else {
  159. if r != br {
  160. panic(ErrShape)
  161. }
  162. dst.reuseAs(c, bc)
  163. }
  164. // Do not need to worry about overlap between m and b because x has its own
  165. // independent storage.
  166. w := getWorkspace(max(r, c), bc, false)
  167. w.Copy(b)
  168. t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat
  169. if trans {
  170. ok := lapack64.Trtrs(blas.Trans, t, w.mat)
  171. if !ok {
  172. return Condition(math.Inf(1))
  173. }
  174. for i := c; i < r; i++ {
  175. zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
  176. }
  177. work := []float64{0}
  178. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1)
  179. work = getFloats(int(work[0]), false)
  180. lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work))
  181. putFloats(work)
  182. } else {
  183. work := []float64{0}
  184. lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1)
  185. work = getFloats(int(work[0]), false)
  186. lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work))
  187. putFloats(work)
  188. ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
  189. if !ok {
  190. return Condition(math.Inf(1))
  191. }
  192. }
  193. // X was set above to be the correct size for the result.
  194. dst.Copy(w)
  195. putWorkspace(w)
  196. if qr.cond > ConditionTolerance {
  197. return Condition(qr.cond)
  198. }
  199. return nil
  200. }
  201. // SolveVecTo finds a minimum-norm solution to a system of linear equations,
  202. // Ax = b.
  203. // See QR.SolveTo for the full documentation.
  204. // SolveVecTo will panic if the receiver does not contain a factorization.
  205. func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
  206. if !qr.isValid() {
  207. panic(badQR)
  208. }
  209. r, c := qr.qr.Dims()
  210. if _, bc := b.Dims(); bc != 1 {
  211. panic(ErrShape)
  212. }
  213. // The Solve implementation is non-trivial, so rather than duplicate the code,
  214. // instead recast the VecDenses as Dense and call the matrix code.
  215. bm := Matrix(b)
  216. if rv, ok := b.(RawVectorer); ok {
  217. bmat := rv.RawVector()
  218. if dst != b {
  219. dst.checkOverlap(bmat)
  220. }
  221. b := VecDense{mat: bmat}
  222. bm = b.asDense()
  223. }
  224. if trans {
  225. dst.reuseAs(r)
  226. } else {
  227. dst.reuseAs(c)
  228. }
  229. return qr.SolveTo(dst.asDense(), trans, bm)
  230. }