lu.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/floats"
  10. "gonum.org/v1/gonum/lapack"
  11. "gonum.org/v1/gonum/lapack/lapack64"
  12. )
  // badSliceLength is the panic message used when a caller-supplied slice does
  // not have the length an operation requires.
  const badSliceLength = "mat: improper slice length"

  // badLU is the panic message used when an LU method that needs a
  // factorization is called on a receiver that does not hold one.
  const badLU = "mat: invalid LU factorization"
  17. // LU is a type for creating and using the LU factorization of a matrix.
  18. type LU struct {
  19. lu *Dense
  20. pivot []int
  21. cond float64
  22. }
  23. // updateCond updates the stored condition number of the matrix. anorm is the
  24. // norm of the original matrix. If anorm is negative it will be estimated.
  25. func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) {
  26. n := lu.lu.mat.Cols
  27. work := getFloats(4*n, false)
  28. defer putFloats(work)
  29. iwork := getInts(n, false)
  30. defer putInts(iwork)
  31. if anorm < 0 {
  32. // This is an approximation. By the definition of a norm,
  33. // |AB| <= |A| |B|.
  34. // Since A = L*U, we get for the condition number κ that
  35. // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|,
  36. // so this will overestimate the condition number somewhat.
  37. // The norm of the original factorized matrix cannot be stored
  38. // because of update possibilities.
  39. u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper)
  40. l := lu.lu.asTriDense(n, blas.Unit, blas.Lower)
  41. unorm := lapack64.Lantr(norm, u.mat, work)
  42. lnorm := lapack64.Lantr(norm, l.mat, work)
  43. anorm = unorm * lnorm
  44. }
  45. v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork)
  46. lu.cond = 1 / v
  47. }
  48. // Factorize computes the LU factorization of the square matrix a and stores the
  49. // result. The LU decomposition will complete regardless of the singularity of a.
  50. //
  51. // The LU factorization is computed with pivoting, and so really the decomposition
  52. // is a PLU decomposition where P is a permutation matrix. The individual matrix
  53. // factors can be extracted from the factorization using the Permutation method
  54. // on Dense, and the LU LTo and UTo methods.
  55. func (lu *LU) Factorize(a Matrix) {
  56. lu.factorize(a, CondNorm)
  57. }
  58. func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) {
  59. r, c := a.Dims()
  60. if r != c {
  61. panic(ErrSquare)
  62. }
  63. if lu.lu == nil {
  64. lu.lu = NewDense(r, r, nil)
  65. } else {
  66. lu.lu.Reset()
  67. lu.lu.reuseAs(r, r)
  68. }
  69. lu.lu.Copy(a)
  70. if cap(lu.pivot) < r {
  71. lu.pivot = make([]int, r)
  72. }
  73. lu.pivot = lu.pivot[:r]
  74. work := getFloats(r, false)
  75. anorm := lapack64.Lange(norm, lu.lu.mat, work)
  76. putFloats(work)
  77. lapack64.Getrf(lu.lu.mat, lu.pivot)
  78. lu.updateCond(anorm, norm)
  79. }
  80. // isValid returns whether the receiver contains a factorization.
  81. func (lu *LU) isValid() bool {
  82. return lu.lu != nil && !lu.lu.IsZero()
  83. }
  84. // Cond returns the condition number for the factorized matrix.
  85. // Cond will panic if the receiver does not contain a factorization.
  86. func (lu *LU) Cond() float64 {
  87. if !lu.isValid() {
  88. panic(badLU)
  89. }
  90. return lu.cond
  91. }
  92. // Reset resets the factorization so that it can be reused as the receiver of a
  93. // dimensionally restricted operation.
  94. func (lu *LU) Reset() {
  95. if lu.lu != nil {
  96. lu.lu.Reset()
  97. }
  98. lu.pivot = lu.pivot[:0]
  99. }
  100. func (lu *LU) isZero() bool {
  101. return len(lu.pivot) == 0
  102. }
  // Det returns the determinant of the factorized matrix. It is evaluated as
  // sign * exp(LogDet), so it can overflow to ±Inf or underflow to zero for
  // large or badly scaled matrices. LogDet is usually the more stable choice.
  // Det panics if the receiver does not contain a factorization.
  func (lu *LU) Det() float64 {
  	logAbs, sign := lu.LogDet()
  	return sign * math.Exp(logAbs)
  }
  // LogDet returns the natural logarithm of the absolute value of the
  // determinant of the factorized matrix, together with the determinant's sign
  // (+1 or -1). Working in log space avoids the overflow and underflow that the
  // determinant itself can suffer in products and quotients.
  // LogDet panics if the receiver does not contain a factorization.
  func (lu *LU) LogDet() (det float64, sign float64) {
  	if !lu.isValid() {
  		panic(badLU)
  	}
  	_, n := lu.lu.Dims()
  	logs := getFloats(n, false)
  	defer putFloats(logs)

  	// det(A) = det(P) * det(L) * det(U): L has a unit diagonal, every row
  	// interchange flips the sign, and det(U) is the product of its diagonal.
  	sign = 1
  	for i := 0; i < n; i++ {
  		d := lu.lu.at(i, i)
  		if lu.pivot[i] != i {
  			sign = -sign
  		}
  		if d < 0 {
  			sign = -sign
  		}
  		logs[i] = math.Log(math.Abs(d))
  	}
  	return floats.Sum(logs), sign
  }
  134. // Pivot returns pivot indices that enable the construction of the permutation
  135. // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be
  136. // allocated, otherwise the length of the input must be equal to the size of the
  137. // factorized matrix.
  138. // Pivot will panic if the receiver does not contain a factorization.
  139. func (lu *LU) Pivot(swaps []int) []int {
  140. if !lu.isValid() {
  141. panic(badLU)
  142. }
  143. _, n := lu.lu.Dims()
  144. if swaps == nil {
  145. swaps = make([]int, n)
  146. }
  147. if len(swaps) != n {
  148. panic(badSliceLength)
  149. }
  150. // Perform the inverse of the row swaps in order to find the final
  151. // row swap position.
  152. for i := range swaps {
  153. swaps[i] = i
  154. }
  155. for i := n - 1; i >= 0; i-- {
  156. v := lu.pivot[i]
  157. swaps[i], swaps[v] = swaps[v], swaps[i]
  158. }
  159. return swaps
  160. }
  161. // RankOne updates an LU factorization as if a rank-one update had been applied to
  162. // the original matrix A, storing the result into the receiver. That is, if in
  163. // the original LU decomposition P * L * U = A, in the updated decomposition
  164. // P * L * U = A + alpha * x * y^T.
  165. // RankOne will panic if orig does not contain a factorization.
  166. func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) {
  167. if !orig.isValid() {
  168. panic(badLU)
  169. }
  170. // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix
  171. // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng.
  172. // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf
  173. _, n := orig.lu.Dims()
  174. if r, c := x.Dims(); r != n || c != 1 {
  175. panic(ErrShape)
  176. }
  177. if r, c := y.Dims(); r != n || c != 1 {
  178. panic(ErrShape)
  179. }
  180. if orig != lu {
  181. if lu.isZero() {
  182. if cap(lu.pivot) < n {
  183. lu.pivot = make([]int, n)
  184. }
  185. lu.pivot = lu.pivot[:n]
  186. if lu.lu == nil {
  187. lu.lu = NewDense(n, n, nil)
  188. } else {
  189. lu.lu.reuseAs(n, n)
  190. }
  191. } else if len(lu.pivot) != n {
  192. panic(ErrShape)
  193. }
  194. copy(lu.pivot, orig.pivot)
  195. lu.lu.Copy(orig.lu)
  196. }
  197. xs := getFloats(n, false)
  198. defer putFloats(xs)
  199. ys := getFloats(n, false)
  200. defer putFloats(ys)
  201. for i := 0; i < n; i++ {
  202. xs[i] = x.AtVec(i)
  203. ys[i] = y.AtVec(i)
  204. }
  205. // Adjust for the pivoting in the LU factorization
  206. for i, v := range lu.pivot {
  207. xs[i], xs[v] = xs[v], xs[i]
  208. }
  209. lum := lu.lu.mat
  210. omega := alpha
  211. for j := 0; j < n; j++ {
  212. ujj := lum.Data[j*lum.Stride+j]
  213. ys[j] /= ujj
  214. theta := 1 + xs[j]*ys[j]*omega
  215. beta := omega * ys[j] / theta
  216. gamma := omega * xs[j]
  217. omega -= beta * gamma
  218. lum.Data[j*lum.Stride+j] *= theta
  219. for i := j + 1; i < n; i++ {
  220. xs[i] -= lum.Data[i*lum.Stride+j] * xs[j]
  221. tmp := ys[i]
  222. ys[i] -= lum.Data[j*lum.Stride+i] * ys[j]
  223. lum.Data[i*lum.Stride+j] += beta * xs[i]
  224. lum.Data[j*lum.Stride+i] += gamma * tmp
  225. }
  226. }
  227. lu.updateCond(-1, CondNorm)
  228. }
  229. // LTo extracts the lower triangular matrix from an LU factorization.
  230. // If dst is nil, a new matrix is allocated. The resulting L matrix is returned.
  231. // LTo will panic if the receiver does not contain a factorization.
  232. func (lu *LU) LTo(dst *TriDense) *TriDense {
  233. if !lu.isValid() {
  234. panic(badLU)
  235. }
  236. _, n := lu.lu.Dims()
  237. if dst == nil {
  238. dst = NewTriDense(n, Lower, nil)
  239. } else {
  240. dst.reuseAs(n, Lower)
  241. }
  242. // Extract the lower triangular elements.
  243. for i := 0; i < n; i++ {
  244. for j := 0; j < i; j++ {
  245. dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
  246. }
  247. }
  248. // Set ones on the diagonal.
  249. for i := 0; i < n; i++ {
  250. dst.mat.Data[i*dst.mat.Stride+i] = 1
  251. }
  252. return dst
  253. }
  254. // UTo extracts the upper triangular matrix from an LU factorization.
  255. // If dst is nil, a new matrix is allocated. The resulting U matrix is returned.
  256. // UTo will panic if the receiver does not contain a factorization.
  257. func (lu *LU) UTo(dst *TriDense) *TriDense {
  258. if !lu.isValid() {
  259. panic(badLU)
  260. }
  261. _, n := lu.lu.Dims()
  262. if dst == nil {
  263. dst = NewTriDense(n, Upper, nil)
  264. } else {
  265. dst.reuseAs(n, Upper)
  266. }
  267. // Extract the upper triangular elements.
  268. for i := 0; i < n; i++ {
  269. for j := i; j < n; j++ {
  270. dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
  271. }
  272. }
  273. return dst
  274. }
  275. // Permutation constructs an r×r permutation matrix with the given row swaps.
  276. // A permutation matrix has exactly one element equal to one in each row and column
  277. // and all other elements equal to zero. swaps[i] specifies the row with which
  278. // i will be swapped, which is equivalent to the non-zero column of row i.
  279. func (m *Dense) Permutation(r int, swaps []int) {
  280. m.reuseAs(r, r)
  281. for i := 0; i < r; i++ {
  282. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r])
  283. v := swaps[i]
  284. if v < 0 || v >= r {
  285. panic(ErrRowAccess)
  286. }
  287. m.mat.Data[i*m.mat.Stride+v] = 1
  288. }
  289. }
  // SolveTo solves a system of linear equations using the LU decomposition of a matrix.
  // It computes
  //  A * X = B if trans == false
  //  A^T * X = B if trans == true
  // In both cases, A is represented in LU factorized form, and the matrix X is
  // stored into dst.
  //
  // If A is exactly singular (U has a zero on its diagonal), dst is left
  // untouched and Condition(+Inf) is returned. If A is near-singular, X is
  // still computed and a Condition error carrying the estimated condition
  // number is returned. See the documentation for Condition for more
  // information.
  // SolveTo will panic if the receiver does not contain a factorization, and
  // with ErrShape if b does not have as many rows as the factorized matrix.
  func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error {
  	if !lu.isValid() {
  		panic(badLU)
  	}

  	_, n := lu.lu.Dims()
  	br, bc := b.Dims()
  	if br != n {
  		panic(ErrShape)
  	}
  	// A is exactly singular iff U has a zero on its diagonal. This is tested
  	// directly rather than with Det() == 0. Det is exp(sum of log|u_ii|), which
  	// underflows to zero for well-conditioned but small-scaled matrices such as
  	// 0.1*I of order 400, and becomes NaN when the diagonal holds both a zero
  	// and an Inf.
  	// TODO: consider testing the condition number here instead.
  	for i := 0; i < n; i++ {
  		if lu.lu.at(i, i) == 0 {
  			return Condition(math.Inf(1))
  		}
  	}

  	dst.reuseAs(n, bc)
  	bU, _ := untranspose(b)
  	var restore func()
  	if dst == bU {
  		dst, restore = dst.isolatedWorkspace(bU)
  		defer restore()
  	} else if rm, ok := bU.(RawMatrixer); ok {
  		dst.checkOverlap(rm.RawMatrix())
  	}

  	dst.Copy(b)
  	t := blas.NoTrans
  	if trans {
  		t = blas.Trans
  	}
  	lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot)
  	if lu.cond > ConditionTolerance {
  		return Condition(lu.cond)
  	}
  	return nil
  }
  // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix.
  // It computes
  //  A * x = b if trans == false
  //  A^T * x = b if trans == true
  // In both cases, A is represented in LU factorized form, and the vector x is
  // stored into dst.
  //
  // If A is exactly singular (U has a zero on its diagonal), Condition(+Inf) is
  // returned without solving. If A is near-singular, x is still computed and a
  // Condition error carrying the estimated condition number is returned. See
  // the documentation for Condition for more information.
  // SolveVecTo will panic if the receiver does not contain a factorization, and
  // with ErrShape if b is not an n×1 vector for the n×n factorized matrix.
  func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
  	if !lu.isValid() {
  		panic(badLU)
  	}
  	_, n := lu.lu.Dims()
  	if br, bc := b.Dims(); br != n || bc != 1 {
  		panic(ErrShape)
  	}

  	// singular reports whether A is exactly singular, i.e. whether U has a zero
  	// on its diagonal. Unlike Det() == 0 it is not fooled by the determinant
  	// underflowing to zero for well-conditioned but small-scaled matrices.
  	singular := func() bool {
  		for i := 0; i < n; i++ {
  			if lu.lu.at(i, i) == 0 {
  				return true
  			}
  		}
  		return false
  	}

  	switch rv := b.(type) {
  	default:
  		dst.reuseAs(n)
  		if singular() {
  			return Condition(math.Inf(1))
  		}
  		return lu.SolveTo(dst.asDense(), trans, b)
  	case RawVectorer:
  		if dst != b {
  			dst.checkOverlap(rv.RawVector())
  		}
  		if singular() {
  			return Condition(math.Inf(1))
  		}
  		dst.reuseAs(n)
  		var restore func()
  		if dst == b {
  			dst, restore = dst.isolatedWorkspace(b)
  			defer restore()
  		}
  		dst.CopyVec(b)
  		// View the vector as an n×1 general matrix for Getrs.
  		vMat := blas64.General{
  			Rows:   n,
  			Cols:   1,
  			Stride: dst.mat.Inc,
  			Data:   dst.mat.Data,
  		}
  		t := blas.NoTrans
  		if trans {
  			t = blas.Trans
  		}
  		lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
  		if lu.cond > ConditionTolerance {
  			return Condition(lu.cond)
  		}
  		return nil
  	}
  }
  386. return nil
  387. }
  388. }