matrix.go 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/floats"
  10. "gonum.org/v1/gonum/lapack"
  11. "gonum.org/v1/gonum/lapack/lapack64"
  12. )
  13. // Matrix is the basic matrix interface type.
  14. type Matrix interface {
  15. // Dims returns the dimensions of a Matrix.
  16. Dims() (r, c int)
  17. // At returns the value of a matrix element at row i, column j.
  18. // It will panic if i or j are out of bounds for the matrix.
  19. At(i, j int) float64
  20. // T returns the transpose of the Matrix. Whether T returns a copy of the
  21. // underlying data is implementation dependent.
  22. // This method may be implemented using the Transpose type, which
  23. // provides an implicit matrix transpose.
  24. T() Matrix
  25. }
  26. var (
  27. _ Matrix = Transpose{}
  28. _ Untransposer = Transpose{}
  29. )
  30. // Transpose is a type for performing an implicit matrix transpose. It implements
  31. // the Matrix interface, returning values from the transpose of the matrix within.
  32. type Transpose struct {
  33. Matrix Matrix
  34. }
  35. // At returns the value of the element at row i and column j of the transposed
  36. // matrix, that is, row j and column i of the Matrix field.
  37. func (t Transpose) At(i, j int) float64 {
  38. return t.Matrix.At(j, i)
  39. }
  40. // Dims returns the dimensions of the transposed matrix. The number of rows returned
  41. // is the number of columns in the Matrix field, and the number of columns is
  42. // the number of rows in the Matrix field.
  43. func (t Transpose) Dims() (r, c int) {
  44. c, r = t.Matrix.Dims()
  45. return r, c
  46. }
  47. // T performs an implicit transpose by returning the Matrix field.
  48. func (t Transpose) T() Matrix {
  49. return t.Matrix
  50. }
  51. // Untranspose returns the Matrix field.
  52. func (t Transpose) Untranspose() Matrix {
  53. return t.Matrix
  54. }
  55. // Untransposer is a type that can undo an implicit transpose.
  56. type Untransposer interface {
  57. // Note: This interface is needed to unify all of the Transpose types. In
  58. // the mat methods, we need to test if the Matrix has been implicitly
  59. // transposed. If this is checked by testing for the specific Transpose type
  60. // then the behavior will be different if the user uses T() or TTri() for a
  61. // triangular matrix.
  62. // Untranspose returns the underlying Matrix stored for the implicit transpose.
  63. Untranspose() Matrix
  64. }
  65. // UntransposeBander is a type that can undo an implicit band transpose.
  66. type UntransposeBander interface {
  67. // Untranspose returns the underlying Banded stored for the implicit transpose.
  68. UntransposeBand() Banded
  69. }
  70. // UntransposeTrier is a type that can undo an implicit triangular transpose.
  71. type UntransposeTrier interface {
  72. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  73. UntransposeTri() Triangular
  74. }
  75. // UntransposeTriBander is a type that can undo an implicit triangular banded
  76. // transpose.
  77. type UntransposeTriBander interface {
  78. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  79. UntransposeTriBand() TriBanded
  80. }
  81. // Mutable is a matrix interface type that allows elements to be altered.
  82. type Mutable interface {
  83. // Set alters the matrix element at row i, column j to v.
  84. // It will panic if i or j are out of bounds for the matrix.
  85. Set(i, j int, v float64)
  86. Matrix
  87. }
  88. // A RowViewer can return a Vector reflecting a row that is backed by the matrix
  89. // data. The Vector returned will have length equal to the number of columns.
  90. type RowViewer interface {
  91. RowView(i int) Vector
  92. }
  93. // A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
  94. // data.
  95. type RawRowViewer interface {
  96. RawRowView(i int) []float64
  97. }
  98. // A ColViewer can return a Vector reflecting a column that is backed by the matrix
  99. // data. The Vector returned will have length equal to the number of rows.
  100. type ColViewer interface {
  101. ColView(j int) Vector
  102. }
  103. // A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
  104. // data.
  105. type RawColViewer interface {
  106. RawColView(j int) []float64
  107. }
  108. // A Cloner can make a copy of a into the receiver, overwriting the previous value of the
  109. // receiver. The clone operation does not make any restriction on shape and will not cause
  110. // shadowing.
  111. type Cloner interface {
  112. Clone(a Matrix)
  113. }
  114. // A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
  115. // restricted operation. This is commonly used when the matrix is being used as a workspace
  116. // or temporary matrix.
  117. //
  118. // If the matrix is a view, using the reset matrix may result in data corruption in elements
  119. // outside the view.
  120. type Reseter interface {
  121. Reset()
  122. }
  123. // A Copier can make a copy of elements of a into the receiver. The submatrix copied
  124. // starts at row and column 0 and has dimensions equal to the minimum dimensions of
  125. // the two matrices. The number of row and columns copied is returned.
  126. // Copy will copy from a source that aliases the receiver unless the source is transposed;
  127. // an aliasing transpose copy will panic with the exception for a special case when
  128. // the source data has a unitary increment or stride.
  129. type Copier interface {
  130. Copy(a Matrix) (r, c int)
  131. }
  132. // A Grower can grow the size of the represented matrix by the given number of rows and columns.
  133. // Growing beyond the size given by the Caps method will result in the allocation of a new
  134. // matrix and copying of the elements. If Grow is called with negative increments it will
  135. // panic with ErrIndexOutOfRange.
  136. type Grower interface {
  137. Caps() (r, c int)
  138. Grow(r, c int) Matrix
  139. }
  140. // A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
  141. // k2.
  142. type BandWidther interface {
  143. BandWidth() (k1, k2 int)
  144. }
  145. // A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
  146. // on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
  147. type RawMatrixSetter interface {
  148. SetRawMatrix(a blas64.General)
  149. }
  150. // A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
  151. // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
  152. type RawMatrixer interface {
  153. RawMatrix() blas64.General
  154. }
  155. // A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
  156. // slice will be reflected in the original matrix, changes to the Inc field will not.
  157. type RawVectorer interface {
  158. RawVector() blas64.Vector
  159. }
  160. // A NonZeroDoer can call a function for each non-zero element of the receiver.
  161. // The parameters of the function are the element indices and its value.
  162. type NonZeroDoer interface {
  163. DoNonZero(func(i, j int, v float64))
  164. }
  165. // A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
  166. // The parameters of the function are the element indices and its value.
  167. type RowNonZeroDoer interface {
  168. DoRowNonZero(i int, fn func(i, j int, v float64))
  169. }
  170. // A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
  171. // The parameters of the function are the element indices and its value.
  172. type ColNonZeroDoer interface {
  173. DoColNonZero(j int, fn func(i, j int, v float64))
  174. }
  175. // untranspose untransposes a matrix if applicable. If a is an Untransposer, then
  176. // untranspose returns the underlying matrix and true. If it is not, then it returns
  177. // the input matrix and false.
  178. func untranspose(a Matrix) (Matrix, bool) {
  179. if ut, ok := a.(Untransposer); ok {
  180. return ut.Untranspose(), true
  181. }
  182. return a, false
  183. }
  184. // untransposeExtract returns an untransposed matrix in a built-in matrix type.
  185. //
  186. // The untransposed matrix is returned unaltered if it is a built-in matrix type.
  187. // Otherwise, if it implements a Raw method, an appropriate built-in type value
  188. // is returned holding the raw matrix value of the input. If neither of these
  189. // is possible, the untransposed matrix is returned.
  190. func untransposeExtract(a Matrix) (Matrix, bool) {
  191. ut, trans := untranspose(a)
  192. switch m := ut.(type) {
  193. case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
  194. return m, trans
  195. // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
  196. case RawSymBander:
  197. rsb := m.RawSymBand()
  198. if rsb.Uplo != blas.Upper {
  199. return ut, trans
  200. }
  201. var sb SymBandDense
  202. sb.SetRawSymBand(rsb)
  203. return &sb, trans
  204. case RawTriBander:
  205. rtb := m.RawTriBand()
  206. if rtb.Diag == blas.Unit {
  207. return ut, trans
  208. }
  209. var tb TriBandDense
  210. tb.SetRawTriBand(rtb)
  211. return &tb, trans
  212. case RawBander:
  213. var b BandDense
  214. b.SetRawBand(m.RawBand())
  215. return &b, trans
  216. case RawTriangular:
  217. rt := m.RawTriangular()
  218. if rt.Diag == blas.Unit {
  219. return ut, trans
  220. }
  221. var t TriDense
  222. t.SetRawTriangular(rt)
  223. return &t, trans
  224. case RawSymmetricer:
  225. rs := m.RawSymmetric()
  226. if rs.Uplo != blas.Upper {
  227. return ut, trans
  228. }
  229. var s SymDense
  230. s.SetRawSymmetric(rs)
  231. return &s, trans
  232. case RawMatrixer:
  233. var d Dense
  234. d.SetRawMatrix(m.RawMatrix())
  235. return &d, trans
  236. default:
  237. return ut, trans
  238. }
  239. }
  240. // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
  241. // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
  242. // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
  243. // Col copies the elements in the jth column of the matrix into the slice dst.
  244. // The length of the provided slice must equal the number of rows, unless the
  245. // slice is nil in which case a new slice is first allocated.
  246. func Col(dst []float64, j int, a Matrix) []float64 {
  247. r, c := a.Dims()
  248. if j < 0 || j >= c {
  249. panic(ErrColAccess)
  250. }
  251. if dst == nil {
  252. dst = make([]float64, r)
  253. } else {
  254. if len(dst) != r {
  255. panic(ErrColLength)
  256. }
  257. }
  258. aU, aTrans := untranspose(a)
  259. if rm, ok := aU.(RawMatrixer); ok {
  260. m := rm.RawMatrix()
  261. if aTrans {
  262. copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
  263. return dst
  264. }
  265. blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
  266. blas64.Vector{N: r, Inc: 1, Data: dst},
  267. )
  268. return dst
  269. }
  270. for i := 0; i < r; i++ {
  271. dst[i] = a.At(i, j)
  272. }
  273. return dst
  274. }
  275. // Row copies the elements in the ith row of the matrix into the slice dst.
  276. // The length of the provided slice must equal the number of columns, unless the
  277. // slice is nil in which case a new slice is first allocated.
  278. func Row(dst []float64, i int, a Matrix) []float64 {
  279. r, c := a.Dims()
  280. if i < 0 || i >= r {
  281. panic(ErrColAccess)
  282. }
  283. if dst == nil {
  284. dst = make([]float64, c)
  285. } else {
  286. if len(dst) != c {
  287. panic(ErrRowLength)
  288. }
  289. }
  290. aU, aTrans := untranspose(a)
  291. if rm, ok := aU.(RawMatrixer); ok {
  292. m := rm.RawMatrix()
  293. if aTrans {
  294. blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
  295. blas64.Vector{N: c, Inc: 1, Data: dst},
  296. )
  297. return dst
  298. }
  299. copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
  300. return dst
  301. }
  302. for j := 0; j < c; j++ {
  303. dst[j] = a.At(i, j)
  304. }
  305. return dst
  306. }
  307. // Cond returns the condition number of the given matrix under the given norm.
  308. // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
  309. // Cond will panic with matrix.ErrShape if the matrix has zero size.
  310. //
  311. // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
  312. // is innacurate, although is typically the right order of magnitude. See
  313. // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
  314. // change with the resolution of this bug, the result from Cond will match the
  315. // condition number used internally.
  316. func Cond(a Matrix, norm float64) float64 {
  317. m, n := a.Dims()
  318. if m == 0 || n == 0 {
  319. panic(ErrShape)
  320. }
  321. var lnorm lapack.MatrixNorm
  322. switch norm {
  323. default:
  324. panic("mat: bad norm value")
  325. case 1:
  326. lnorm = lapack.MaxColumnSum
  327. case 2:
  328. var svd SVD
  329. ok := svd.Factorize(a, SVDNone)
  330. if !ok {
  331. return math.Inf(1)
  332. }
  333. return svd.Cond()
  334. case math.Inf(1):
  335. lnorm = lapack.MaxRowSum
  336. }
  337. if m == n {
  338. // Use the LU decomposition to compute the condition number.
  339. var lu LU
  340. lu.factorize(a, lnorm)
  341. return lu.Cond()
  342. }
  343. if m > n {
  344. // Use the QR factorization to compute the condition number.
  345. var qr QR
  346. qr.factorize(a, lnorm)
  347. return qr.Cond()
  348. }
  349. // Use the LQ factorization to compute the condition number.
  350. var lq LQ
  351. lq.factorize(a, lnorm)
  352. return lq.Cond()
  353. }
  354. // Det returns the determinant of the matrix a. In many expressions using LogDet
  355. // will be more numerically stable.
  356. func Det(a Matrix) float64 {
  357. det, sign := LogDet(a)
  358. return math.Exp(det) * sign
  359. }
  360. // Dot returns the sum of the element-wise product of a and b.
  361. // Dot panics if the matrix sizes are unequal.
  362. func Dot(a, b Vector) float64 {
  363. la := a.Len()
  364. lb := b.Len()
  365. if la != lb {
  366. panic(ErrShape)
  367. }
  368. if arv, ok := a.(RawVectorer); ok {
  369. if brv, ok := b.(RawVectorer); ok {
  370. return blas64.Dot(arv.RawVector(), brv.RawVector())
  371. }
  372. }
  373. var sum float64
  374. for i := 0; i < la; i++ {
  375. sum += a.At(i, 0) * b.At(i, 0)
  376. }
  377. return sum
  378. }
  379. // Equal returns whether the matrices a and b have the same size
  380. // and are element-wise equal.
  381. func Equal(a, b Matrix) bool {
  382. ar, ac := a.Dims()
  383. br, bc := b.Dims()
  384. if ar != br || ac != bc {
  385. return false
  386. }
  387. aU, aTrans := untranspose(a)
  388. bU, bTrans := untranspose(b)
  389. if rma, ok := aU.(RawMatrixer); ok {
  390. if rmb, ok := bU.(RawMatrixer); ok {
  391. ra := rma.RawMatrix()
  392. rb := rmb.RawMatrix()
  393. if aTrans == bTrans {
  394. for i := 0; i < ra.Rows; i++ {
  395. for j := 0; j < ra.Cols; j++ {
  396. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  397. return false
  398. }
  399. }
  400. }
  401. return true
  402. }
  403. for i := 0; i < ra.Rows; i++ {
  404. for j := 0; j < ra.Cols; j++ {
  405. if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
  406. return false
  407. }
  408. }
  409. }
  410. return true
  411. }
  412. }
  413. if rma, ok := aU.(RawSymmetricer); ok {
  414. if rmb, ok := bU.(RawSymmetricer); ok {
  415. ra := rma.RawSymmetric()
  416. rb := rmb.RawSymmetric()
  417. // Symmetric matrices are always upper and equal to their transpose.
  418. for i := 0; i < ra.N; i++ {
  419. for j := i; j < ra.N; j++ {
  420. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  421. return false
  422. }
  423. }
  424. }
  425. return true
  426. }
  427. }
  428. if ra, ok := aU.(*VecDense); ok {
  429. if rb, ok := bU.(*VecDense); ok {
  430. // If the raw vectors are the same length they must either both be
  431. // transposed or both not transposed (or have length 1).
  432. for i := 0; i < ra.mat.N; i++ {
  433. if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
  434. return false
  435. }
  436. }
  437. return true
  438. }
  439. }
  440. for i := 0; i < ar; i++ {
  441. for j := 0; j < ac; j++ {
  442. if a.At(i, j) != b.At(i, j) {
  443. return false
  444. }
  445. }
  446. }
  447. return true
  448. }
  449. // EqualApprox returns whether the matrices a and b have the same size and contain all equal
  450. // elements with tolerance for element-wise equality specified by epsilon. Matrices
  451. // with non-equal shapes are not equal.
  452. func EqualApprox(a, b Matrix, epsilon float64) bool {
  453. ar, ac := a.Dims()
  454. br, bc := b.Dims()
  455. if ar != br || ac != bc {
  456. return false
  457. }
  458. aU, aTrans := untranspose(a)
  459. bU, bTrans := untranspose(b)
  460. if rma, ok := aU.(RawMatrixer); ok {
  461. if rmb, ok := bU.(RawMatrixer); ok {
  462. ra := rma.RawMatrix()
  463. rb := rmb.RawMatrix()
  464. if aTrans == bTrans {
  465. for i := 0; i < ra.Rows; i++ {
  466. for j := 0; j < ra.Cols; j++ {
  467. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  468. return false
  469. }
  470. }
  471. }
  472. return true
  473. }
  474. for i := 0; i < ra.Rows; i++ {
  475. for j := 0; j < ra.Cols; j++ {
  476. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
  477. return false
  478. }
  479. }
  480. }
  481. return true
  482. }
  483. }
  484. if rma, ok := aU.(RawSymmetricer); ok {
  485. if rmb, ok := bU.(RawSymmetricer); ok {
  486. ra := rma.RawSymmetric()
  487. rb := rmb.RawSymmetric()
  488. // Symmetric matrices are always upper and equal to their transpose.
  489. for i := 0; i < ra.N; i++ {
  490. for j := i; j < ra.N; j++ {
  491. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  492. return false
  493. }
  494. }
  495. }
  496. return true
  497. }
  498. }
  499. if ra, ok := aU.(*VecDense); ok {
  500. if rb, ok := bU.(*VecDense); ok {
  501. // If the raw vectors are the same length they must either both be
  502. // transposed or both not transposed (or have length 1).
  503. for i := 0; i < ra.mat.N; i++ {
  504. if !floats.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
  505. return false
  506. }
  507. }
  508. return true
  509. }
  510. }
  511. for i := 0; i < ar; i++ {
  512. for j := 0; j < ac; j++ {
  513. if !floats.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
  514. return false
  515. }
  516. }
  517. }
  518. return true
  519. }
  520. // LogDet returns the log of the determinant and the sign of the determinant
  521. // for the matrix that has been factorized. Numerical stability in product and
  522. // division expressions is generally improved by working in log space.
  523. func LogDet(a Matrix) (det float64, sign float64) {
  524. // TODO(btracey): Add specialized routines for TriDense, etc.
  525. var lu LU
  526. lu.Factorize(a)
  527. return lu.LogDet()
  528. }
  529. // Max returns the largest element value of the matrix A.
  530. // Max will panic with matrix.ErrShape if the matrix has zero size.
  531. func Max(a Matrix) float64 {
  532. r, c := a.Dims()
  533. if r == 0 || c == 0 {
  534. panic(ErrShape)
  535. }
  536. // Max(A) = Max(A^T)
  537. aU, _ := untranspose(a)
  538. switch m := aU.(type) {
  539. case RawMatrixer:
  540. rm := m.RawMatrix()
  541. max := math.Inf(-1)
  542. for i := 0; i < rm.Rows; i++ {
  543. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  544. if v > max {
  545. max = v
  546. }
  547. }
  548. }
  549. return max
  550. case RawTriangular:
  551. rm := m.RawTriangular()
  552. // The max of a triangular is at least 0 unless the size is 1.
  553. if rm.N == 1 {
  554. return rm.Data[0]
  555. }
  556. max := 0.0
  557. if rm.Uplo == blas.Upper {
  558. for i := 0; i < rm.N; i++ {
  559. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  560. if v > max {
  561. max = v
  562. }
  563. }
  564. }
  565. return max
  566. }
  567. for i := 0; i < rm.N; i++ {
  568. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  569. if v > max {
  570. max = v
  571. }
  572. }
  573. }
  574. return max
  575. case RawSymmetricer:
  576. rm := m.RawSymmetric()
  577. if rm.Uplo != blas.Upper {
  578. panic(badSymTriangle)
  579. }
  580. max := math.Inf(-1)
  581. for i := 0; i < rm.N; i++ {
  582. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  583. if v > max {
  584. max = v
  585. }
  586. }
  587. }
  588. return max
  589. default:
  590. r, c := aU.Dims()
  591. max := math.Inf(-1)
  592. for i := 0; i < r; i++ {
  593. for j := 0; j < c; j++ {
  594. v := aU.At(i, j)
  595. if v > max {
  596. max = v
  597. }
  598. }
  599. }
  600. return max
  601. }
  602. }
  603. // Min returns the smallest element value of the matrix A.
  604. // Min will panic with matrix.ErrShape if the matrix has zero size.
  605. func Min(a Matrix) float64 {
  606. r, c := a.Dims()
  607. if r == 0 || c == 0 {
  608. panic(ErrShape)
  609. }
  610. // Min(A) = Min(A^T)
  611. aU, _ := untranspose(a)
  612. switch m := aU.(type) {
  613. case RawMatrixer:
  614. rm := m.RawMatrix()
  615. min := math.Inf(1)
  616. for i := 0; i < rm.Rows; i++ {
  617. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  618. if v < min {
  619. min = v
  620. }
  621. }
  622. }
  623. return min
  624. case RawTriangular:
  625. rm := m.RawTriangular()
  626. // The min of a triangular is at most 0 unless the size is 1.
  627. if rm.N == 1 {
  628. return rm.Data[0]
  629. }
  630. min := 0.0
  631. if rm.Uplo == blas.Upper {
  632. for i := 0; i < rm.N; i++ {
  633. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  634. if v < min {
  635. min = v
  636. }
  637. }
  638. }
  639. return min
  640. }
  641. for i := 0; i < rm.N; i++ {
  642. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  643. if v < min {
  644. min = v
  645. }
  646. }
  647. }
  648. return min
  649. case RawSymmetricer:
  650. rm := m.RawSymmetric()
  651. if rm.Uplo != blas.Upper {
  652. panic(badSymTriangle)
  653. }
  654. min := math.Inf(1)
  655. for i := 0; i < rm.N; i++ {
  656. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  657. if v < min {
  658. min = v
  659. }
  660. }
  661. }
  662. return min
  663. default:
  664. r, c := aU.Dims()
  665. min := math.Inf(1)
  666. for i := 0; i < r; i++ {
  667. for j := 0; j < c; j++ {
  668. v := aU.At(i, j)
  669. if v < min {
  670. min = v
  671. }
  672. }
  673. }
  674. return min
  675. }
  676. }
  677. // Norm returns the specified (induced) norm of the matrix a. See
  678. // https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm.
  679. //
  680. // Valid norms are:
  681. // 1 - The maximum absolute column sum
  682. // 2 - Frobenius norm, the square root of the sum of the squares of the elements.
  683. // Inf - The maximum absolute row sum.
  684. // Norm will panic with ErrNormOrder if an illegal norm order is specified and
  685. // with matrix.ErrShape if the matrix has zero size.
  686. func Norm(a Matrix, norm float64) float64 {
  687. r, c := a.Dims()
  688. if r == 0 || c == 0 {
  689. panic(ErrShape)
  690. }
  691. aU, aTrans := untranspose(a)
  692. var work []float64
  693. switch rma := aU.(type) {
  694. case RawMatrixer:
  695. rm := rma.RawMatrix()
  696. n := normLapack(norm, aTrans)
  697. if n == lapack.MaxColumnSum {
  698. work = getFloats(rm.Cols, false)
  699. defer putFloats(work)
  700. }
  701. return lapack64.Lange(n, rm, work)
  702. case RawTriangular:
  703. rm := rma.RawTriangular()
  704. n := normLapack(norm, aTrans)
  705. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  706. work = getFloats(rm.N, false)
  707. defer putFloats(work)
  708. }
  709. return lapack64.Lantr(n, rm, work)
  710. case RawSymmetricer:
  711. rm := rma.RawSymmetric()
  712. n := normLapack(norm, aTrans)
  713. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  714. work = getFloats(rm.N, false)
  715. defer putFloats(work)
  716. }
  717. return lapack64.Lansy(n, rm, work)
  718. case *VecDense:
  719. rv := rma.RawVector()
  720. switch norm {
  721. default:
  722. panic("unreachable")
  723. case 1:
  724. if aTrans {
  725. imax := blas64.Iamax(rv)
  726. return math.Abs(rma.At(imax, 0))
  727. }
  728. return blas64.Asum(rv)
  729. case 2:
  730. return blas64.Nrm2(rv)
  731. case math.Inf(1):
  732. if aTrans {
  733. return blas64.Asum(rv)
  734. }
  735. imax := blas64.Iamax(rv)
  736. return math.Abs(rma.At(imax, 0))
  737. }
  738. }
  739. switch norm {
  740. default:
  741. panic("unreachable")
  742. case 1:
  743. var max float64
  744. for j := 0; j < c; j++ {
  745. var sum float64
  746. for i := 0; i < r; i++ {
  747. sum += math.Abs(a.At(i, j))
  748. }
  749. if sum > max {
  750. max = sum
  751. }
  752. }
  753. return max
  754. case 2:
  755. var sum float64
  756. for i := 0; i < r; i++ {
  757. for j := 0; j < c; j++ {
  758. v := a.At(i, j)
  759. sum += v * v
  760. }
  761. }
  762. return math.Sqrt(sum)
  763. case math.Inf(1):
  764. var max float64
  765. for i := 0; i < r; i++ {
  766. var sum float64
  767. for j := 0; j < c; j++ {
  768. sum += math.Abs(a.At(i, j))
  769. }
  770. if sum > max {
  771. max = sum
  772. }
  773. }
  774. return max
  775. }
  776. }
  777. // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
  778. func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
  779. switch norm {
  780. case 1:
  781. n := lapack.MaxColumnSum
  782. if aTrans {
  783. n = lapack.MaxRowSum
  784. }
  785. return n
  786. case 2:
  787. return lapack.Frobenius
  788. case math.Inf(1):
  789. n := lapack.MaxRowSum
  790. if aTrans {
  791. n = lapack.MaxColumnSum
  792. }
  793. return n
  794. default:
  795. panic(ErrNormOrder)
  796. }
  797. }
  798. // Sum returns the sum of the elements of the matrix.
  799. func Sum(a Matrix) float64 {
  800. // TODO(btracey): Add a fast path for the other supported matrix types.
  801. r, c := a.Dims()
  802. var sum float64
  803. aU, _ := untranspose(a)
  804. if rma, ok := aU.(RawMatrixer); ok {
  805. rm := rma.RawMatrix()
  806. for i := 0; i < rm.Rows; i++ {
  807. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  808. sum += v
  809. }
  810. }
  811. return sum
  812. }
  813. for i := 0; i < r; i++ {
  814. for j := 0; j < c; j++ {
  815. sum += a.At(i, j)
  816. }
  817. }
  818. return sum
  819. }
  820. // A Tracer can compute the trace of the matrix. Trace must panic if the
  821. // matrix is not square.
  822. type Tracer interface {
  823. Trace() float64
  824. }
  825. // Trace returns the trace of the matrix. Trace will panic if the
  826. // matrix is not square.
  827. func Trace(a Matrix) float64 {
  828. m, _ := untransposeExtract(a)
  829. if t, ok := m.(Tracer); ok {
  830. return t.Trace()
  831. }
  832. r, c := a.Dims()
  833. if r != c {
  834. panic(ErrSquare)
  835. }
  836. var v float64
  837. for i := 0; i < r; i++ {
  838. v += a.At(i, i)
  839. }
  840. return v
  841. }
  842. func min(a, b int) int {
  843. if a < b {
  844. return a
  845. }
  846. return b
  847. }
  848. func max(a, b int) int {
  849. if a > b {
  850. return a
  851. }
  852. return b
  853. }
  854. // use returns a float64 slice with l elements, using f if it
  855. // has the necessary capacity, otherwise creating a new slice.
  856. func use(f []float64, l int) []float64 {
  857. if l <= cap(f) {
  858. return f[:l]
  859. }
  860. return make([]float64, l)
  861. }
  862. // useZeroed returns a float64 slice with l elements, using f if it
  863. // has the necessary capacity, otherwise creating a new slice. The
  864. // elements of the returned slice are guaranteed to be zero.
  865. func useZeroed(f []float64, l int) []float64 {
  866. if l <= cap(f) {
  867. f = f[:l]
  868. zero(f)
  869. return f
  870. }
  871. return make([]float64, l)
  872. }
  873. // zero zeros the given slice's elements.
  874. func zero(f []float64) {
  875. for i := range f {
  876. f[i] = 0
  877. }
  878. }
  879. // useInt returns an int slice with l elements, using i if it
  880. // has the necessary capacity, otherwise creating a new slice.
  881. func useInt(i []int, l int) []int {
  882. if l <= cap(i) {
  883. return i[:l]
  884. }
  885. return make([]int, l)
  886. }