solve.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. "gonum.org/v1/gonum/lapack/lapack64"
  9. )
  10. // Solve finds a minimum-norm solution to a system of linear equations defined
  11. // by the matrices A and B. If A is singular or near-singular, a Condition error
  12. // is returned. See the documentation for Condition for more information.
  13. //
  14. // The minimization problem solved depends on the input parameters:
  15. // - if m >= n, find X such that ||A*X - B||_2 is minimized,
  16. // - if m < n, find the minimum norm solution of A * X = B.
  17. // The solution matrix, X, is stored in-place into the receiver.
  18. func (m *Dense) Solve(a, b Matrix) error {
  19. ar, ac := a.Dims()
  20. br, bc := b.Dims()
  21. if ar != br {
  22. panic(ErrShape)
  23. }
  24. m.reuseAs(ac, bc)
  25. // TODO(btracey): Add special cases for SymDense, etc.
  26. aU, aTrans := untranspose(a)
  27. bU, bTrans := untranspose(b)
  28. switch rma := aU.(type) {
  29. case RawTriangular:
  30. side := blas.Left
  31. tA := blas.NoTrans
  32. if aTrans {
  33. tA = blas.Trans
  34. }
  35. switch rm := bU.(type) {
  36. case RawMatrixer:
  37. if m != bU || bTrans {
  38. if m == bU || m.checkOverlap(rm.RawMatrix()) {
  39. tmp := getWorkspace(br, bc, false)
  40. tmp.Copy(b)
  41. m.Copy(tmp)
  42. putWorkspace(tmp)
  43. break
  44. }
  45. m.Copy(b)
  46. }
  47. default:
  48. if m != bU {
  49. m.Copy(b)
  50. } else if bTrans {
  51. // m and b share data so Copy cannot be used directly.
  52. tmp := getWorkspace(br, bc, false)
  53. tmp.Copy(b)
  54. m.Copy(tmp)
  55. putWorkspace(tmp)
  56. }
  57. }
  58. rm := rma.RawTriangular()
  59. blas64.Trsm(side, tA, 1, rm, m.mat)
  60. work := getFloats(3*rm.N, false)
  61. iwork := getInts(rm.N, false)
  62. cond := lapack64.Trcon(CondNorm, rm, work, iwork)
  63. putFloats(work)
  64. putInts(iwork)
  65. if cond > ConditionTolerance {
  66. return Condition(cond)
  67. }
  68. return nil
  69. }
  70. switch {
  71. case ar == ac:
  72. if a == b {
  73. // x = I.
  74. if ar == 1 {
  75. m.mat.Data[0] = 1
  76. return nil
  77. }
  78. for i := 0; i < ar; i++ {
  79. v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
  80. zero(v)
  81. v[i] = 1
  82. }
  83. return nil
  84. }
  85. var lu LU
  86. lu.Factorize(a)
  87. return lu.SolveTo(m, false, b)
  88. case ar > ac:
  89. var qr QR
  90. qr.Factorize(a)
  91. return qr.SolveTo(m, false, b)
  92. default:
  93. var lq LQ
  94. lq.Factorize(a)
  95. return lq.SolveTo(m, false, b)
  96. }
  97. }
  98. // SolveVec finds a minimum-norm solution to a system of linear equations defined
  99. // by the matrix a and the right-hand side column vector b. If A is singular or
  100. // near-singular, a Condition error is returned. See the documentation for
  101. // Dense.Solve for more information.
  102. func (v *VecDense) SolveVec(a Matrix, b Vector) error {
  103. if _, bc := b.Dims(); bc != 1 {
  104. panic(ErrShape)
  105. }
  106. _, c := a.Dims()
  107. // The Solve implementation is non-trivial, so rather than duplicate the code,
  108. // instead recast the VecDenses as Dense and call the matrix code.
  109. if rv, ok := b.(RawVectorer); ok {
  110. bmat := rv.RawVector()
  111. if v != b {
  112. v.checkOverlap(bmat)
  113. }
  114. v.reuseAs(c)
  115. m := v.asDense()
  116. // We conditionally create bm as m when b and v are identical
  117. // to prevent the overlap detection code from identifying m
  118. // and bm as overlapping but not identical.
  119. bm := m
  120. if v != b {
  121. b := VecDense{mat: bmat}
  122. bm = b.asDense()
  123. }
  124. return m.Solve(a, bm)
  125. }
  126. v.reuseAs(c)
  127. m := v.asDense()
  128. return m.Solve(a, b)
  129. }