matrix.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004
  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/floats"
  10. "gonum.org/v1/gonum/lapack"
  11. "gonum.org/v1/gonum/lapack/lapack64"
  12. )
  // Matrix is the basic matrix interface type. Every matrix exposes its
// dimensions, element access and a transpose.
type Matrix interface {
	// Dims reports the number of rows r and columns c of the Matrix.
	Dims() (r, c int)

	// At returns the element at row i, column j. Implementations panic when
	// i or j is out of bounds for the matrix.
	At(i, j int) float64

	// T returns the transpose of the Matrix. Whether the result copies the
	// underlying data is left to the implementation; the Transpose type can
	// be used to provide an implicit transpose.
	T() Matrix
}
  26. // allMatrix represents the extra set of methods that all mat Matrix types
  27. // should satisfy. This is used to enforce compile-time consistency between the
  28. // Dense types, especially helpful when adding new features.
  29. type allMatrix interface {
  30. Reseter
  31. IsEmpty() bool
  32. Zero()
  33. }
  34. // denseMatrix represents the extra set of methods that all Dense Matrix types
  35. // should satisfy. This is used to enforce compile-time consistency between the
  36. // Dense types, especially helpful when adding new features.
  37. type denseMatrix interface {
  38. DiagView() Diagonal
  39. Tracer
  40. }
  41. var (
  42. _ Matrix = Transpose{}
  43. _ Untransposer = Transpose{}
  44. )
  45. // Transpose is a type for performing an implicit matrix transpose. It implements
  46. // the Matrix interface, returning values from the transpose of the matrix within.
  47. type Transpose struct {
  48. Matrix Matrix
  49. }
  50. // At returns the value of the element at row i and column j of the transposed
  51. // matrix, that is, row j and column i of the Matrix field.
  52. func (t Transpose) At(i, j int) float64 {
  53. return t.Matrix.At(j, i)
  54. }
  55. // Dims returns the dimensions of the transposed matrix. The number of rows returned
  56. // is the number of columns in the Matrix field, and the number of columns is
  57. // the number of rows in the Matrix field.
  58. func (t Transpose) Dims() (r, c int) {
  59. c, r = t.Matrix.Dims()
  60. return r, c
  61. }
  62. // T performs an implicit transpose by returning the Matrix field.
  63. func (t Transpose) T() Matrix {
  64. return t.Matrix
  65. }
  66. // Untranspose returns the Matrix field.
  67. func (t Transpose) Untranspose() Matrix {
  68. return t.Matrix
  69. }
  70. // Untransposer is a type that can undo an implicit transpose.
  71. type Untransposer interface {
  72. // Note: This interface is needed to unify all of the Transpose types. In
  73. // the mat methods, we need to test if the Matrix has been implicitly
  74. // transposed. If this is checked by testing for the specific Transpose type
  75. // then the behavior will be different if the user uses T() or TTri() for a
  76. // triangular matrix.
  77. // Untranspose returns the underlying Matrix stored for the implicit transpose.
  78. Untranspose() Matrix
  79. }
  80. // UntransposeBander is a type that can undo an implicit band transpose.
  81. type UntransposeBander interface {
  82. // Untranspose returns the underlying Banded stored for the implicit transpose.
  83. UntransposeBand() Banded
  84. }
  85. // UntransposeTrier is a type that can undo an implicit triangular transpose.
  86. type UntransposeTrier interface {
  87. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  88. UntransposeTri() Triangular
  89. }
  90. // UntransposeTriBander is a type that can undo an implicit triangular banded
  91. // transpose.
  92. type UntransposeTriBander interface {
  93. // Untranspose returns the underlying Triangular stored for the implicit transpose.
  94. UntransposeTriBand() TriBanded
  95. }
  96. // Mutable is a matrix interface type that allows elements to be altered.
  97. type Mutable interface {
  98. // Set alters the matrix element at row i, column j to v.
  99. // It will panic if i or j are out of bounds for the matrix.
  100. Set(i, j int, v float64)
  101. Matrix
  102. }
  103. // A RowViewer can return a Vector reflecting a row that is backed by the matrix
  104. // data. The Vector returned will have length equal to the number of columns.
  105. type RowViewer interface {
  106. RowView(i int) Vector
  107. }
  108. // A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
  109. // data.
  110. type RawRowViewer interface {
  111. RawRowView(i int) []float64
  112. }
  113. // A ColViewer can return a Vector reflecting a column that is backed by the matrix
  114. // data. The Vector returned will have length equal to the number of rows.
  115. type ColViewer interface {
  116. ColView(j int) Vector
  117. }
  118. // A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
  119. // data.
  120. type RawColViewer interface {
  121. RawColView(j int) []float64
  122. }
  123. // A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the
  124. // receiver. The clone operation does not make any restriction on shape and will not cause
  125. // shadowing.
  126. type ClonerFrom interface {
  127. CloneFrom(a Matrix)
  128. }
  129. // A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
  130. // restricted operation. This is commonly used when the matrix is being used as a workspace
  131. // or temporary matrix.
  132. //
  133. // If the matrix is a view, using Reset may result in data corruption in elements outside
  134. // the view. Similarly, if the matrix shares backing data with another variable, using
  135. // Reset may lead to unexpected changes in data values.
  136. type Reseter interface {
  137. Reset()
  138. }
  139. // A Copier can make a copy of elements of a into the receiver. The submatrix copied
  140. // starts at row and column 0 and has dimensions equal to the minimum dimensions of
  141. // the two matrices. The number of row and columns copied is returned.
  142. // Copy will copy from a source that aliases the receiver unless the source is transposed;
  143. // an aliasing transpose copy will panic with the exception for a special case when
  144. // the source data has a unitary increment or stride.
  145. type Copier interface {
  146. Copy(a Matrix) (r, c int)
  147. }
  148. // A Grower can grow the size of the represented matrix by the given number of rows and columns.
  149. // Growing beyond the size given by the Caps method will result in the allocation of a new
  150. // matrix and copying of the elements. If Grow is called with negative increments it will
  151. // panic with ErrIndexOutOfRange.
  152. type Grower interface {
  153. Caps() (r, c int)
  154. Grow(r, c int) Matrix
  155. }
  156. // A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
  157. // k2.
  158. type BandWidther interface {
  159. BandWidth() (k1, k2 int)
  160. }
  161. // A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
  162. // on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
  163. type RawMatrixSetter interface {
  164. SetRawMatrix(a blas64.General)
  165. }
  166. // A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
  167. // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
  168. type RawMatrixer interface {
  169. RawMatrix() blas64.General
  170. }
  171. // A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
  172. // slice will be reflected in the original matrix, changes to the Inc field will not.
  173. type RawVectorer interface {
  174. RawVector() blas64.Vector
  175. }
  // A NonZeroDoer invokes a callback for every non-zero element of the
// receiver, passing the element's row and column indices and its value.
type NonZeroDoer interface {
	DoNonZero(func(i, j int, v float64))
}

// A RowNonZeroDoer invokes fn for every non-zero element in row i of the
// receiver, passing the element's row and column indices and its value.
type RowNonZeroDoer interface {
	DoRowNonZero(i int, fn func(i, j int, v float64))
}

// A ColNonZeroDoer invokes fn for every non-zero element in column j of the
// receiver, passing the element's row and column indices and its value.
type ColNonZeroDoer interface {
	DoColNonZero(j int, fn func(i, j int, v float64))
}
  191. // untranspose untransposes a matrix if applicable. If a is an Untransposer, then
  192. // untranspose returns the underlying matrix and true. If it is not, then it returns
  193. // the input matrix and false.
  194. func untranspose(a Matrix) (Matrix, bool) {
  195. if ut, ok := a.(Untransposer); ok {
  196. return ut.Untranspose(), true
  197. }
  198. return a, false
  199. }
  200. // untransposeExtract returns an untransposed matrix in a built-in matrix type.
  201. //
  202. // The untransposed matrix is returned unaltered if it is a built-in matrix type.
  203. // Otherwise, if it implements a Raw method, an appropriate built-in type value
  204. // is returned holding the raw matrix value of the input. If neither of these
  205. // is possible, the untransposed matrix is returned.
  206. func untransposeExtract(a Matrix) (Matrix, bool) {
  207. ut, trans := untranspose(a)
  208. switch m := ut.(type) {
  209. case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
  210. return m, trans
  211. // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
  212. case RawSymBander:
  213. rsb := m.RawSymBand()
  214. if rsb.Uplo != blas.Upper {
  215. return ut, trans
  216. }
  217. var sb SymBandDense
  218. sb.SetRawSymBand(rsb)
  219. return &sb, trans
  220. case RawTriBander:
  221. rtb := m.RawTriBand()
  222. if rtb.Diag == blas.Unit {
  223. return ut, trans
  224. }
  225. var tb TriBandDense
  226. tb.SetRawTriBand(rtb)
  227. return &tb, trans
  228. case RawBander:
  229. var b BandDense
  230. b.SetRawBand(m.RawBand())
  231. return &b, trans
  232. case RawTriangular:
  233. rt := m.RawTriangular()
  234. if rt.Diag == blas.Unit {
  235. return ut, trans
  236. }
  237. var t TriDense
  238. t.SetRawTriangular(rt)
  239. return &t, trans
  240. case RawSymmetricer:
  241. rs := m.RawSymmetric()
  242. if rs.Uplo != blas.Upper {
  243. return ut, trans
  244. }
  245. var s SymDense
  246. s.SetRawSymmetric(rs)
  247. return &s, trans
  248. case RawMatrixer:
  249. var d Dense
  250. d.SetRawMatrix(m.RawMatrix())
  251. return &d, trans
  252. default:
  253. return ut, trans
  254. }
  255. }
  256. // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
  257. // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
  258. // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
  259. // Col copies the elements in the jth column of the matrix into the slice dst.
  260. // The length of the provided slice must equal the number of rows, unless the
  261. // slice is nil in which case a new slice is first allocated.
  262. func Col(dst []float64, j int, a Matrix) []float64 {
  263. r, c := a.Dims()
  264. if j < 0 || j >= c {
  265. panic(ErrColAccess)
  266. }
  267. if dst == nil {
  268. dst = make([]float64, r)
  269. } else {
  270. if len(dst) != r {
  271. panic(ErrColLength)
  272. }
  273. }
  274. aU, aTrans := untranspose(a)
  275. if rm, ok := aU.(RawMatrixer); ok {
  276. m := rm.RawMatrix()
  277. if aTrans {
  278. copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
  279. return dst
  280. }
  281. blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
  282. blas64.Vector{N: r, Inc: 1, Data: dst},
  283. )
  284. return dst
  285. }
  286. for i := 0; i < r; i++ {
  287. dst[i] = a.At(i, j)
  288. }
  289. return dst
  290. }
  291. // Row copies the elements in the ith row of the matrix into the slice dst.
  292. // The length of the provided slice must equal the number of columns, unless the
  293. // slice is nil in which case a new slice is first allocated.
  294. func Row(dst []float64, i int, a Matrix) []float64 {
  295. r, c := a.Dims()
  296. if i < 0 || i >= r {
  297. panic(ErrColAccess)
  298. }
  299. if dst == nil {
  300. dst = make([]float64, c)
  301. } else {
  302. if len(dst) != c {
  303. panic(ErrRowLength)
  304. }
  305. }
  306. aU, aTrans := untranspose(a)
  307. if rm, ok := aU.(RawMatrixer); ok {
  308. m := rm.RawMatrix()
  309. if aTrans {
  310. blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
  311. blas64.Vector{N: c, Inc: 1, Data: dst},
  312. )
  313. return dst
  314. }
  315. copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
  316. return dst
  317. }
  318. for j := 0; j < c; j++ {
  319. dst[j] = a.At(i, j)
  320. }
  321. return dst
  322. }
  323. // Cond returns the condition number of the given matrix under the given norm.
  324. // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
  325. // Cond will panic with matrix.ErrShape if the matrix has zero size.
  326. //
  327. // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
  328. // is innacurate, although is typically the right order of magnitude. See
  329. // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
  330. // change with the resolution of this bug, the result from Cond will match the
  331. // condition number used internally.
  332. func Cond(a Matrix, norm float64) float64 {
  333. m, n := a.Dims()
  334. if m == 0 || n == 0 {
  335. panic(ErrShape)
  336. }
  337. var lnorm lapack.MatrixNorm
  338. switch norm {
  339. default:
  340. panic("mat: bad norm value")
  341. case 1:
  342. lnorm = lapack.MaxColumnSum
  343. case 2:
  344. var svd SVD
  345. ok := svd.Factorize(a, SVDNone)
  346. if !ok {
  347. return math.Inf(1)
  348. }
  349. return svd.Cond()
  350. case math.Inf(1):
  351. lnorm = lapack.MaxRowSum
  352. }
  353. if m == n {
  354. // Use the LU decomposition to compute the condition number.
  355. var lu LU
  356. lu.factorize(a, lnorm)
  357. return lu.Cond()
  358. }
  359. if m > n {
  360. // Use the QR factorization to compute the condition number.
  361. var qr QR
  362. qr.factorize(a, lnorm)
  363. return qr.Cond()
  364. }
  365. // Use the LQ factorization to compute the condition number.
  366. var lq LQ
  367. lq.factorize(a, lnorm)
  368. return lq.Cond()
  369. }
  370. // Det returns the determinant of the matrix a. In many expressions using LogDet
  371. // will be more numerically stable.
  372. func Det(a Matrix) float64 {
  373. det, sign := LogDet(a)
  374. return math.Exp(det) * sign
  375. }
  376. // Dot returns the sum of the element-wise product of a and b.
  377. // Dot panics if the matrix sizes are unequal.
  378. func Dot(a, b Vector) float64 {
  379. la := a.Len()
  380. lb := b.Len()
  381. if la != lb {
  382. panic(ErrShape)
  383. }
  384. if arv, ok := a.(RawVectorer); ok {
  385. if brv, ok := b.(RawVectorer); ok {
  386. return blas64.Dot(arv.RawVector(), brv.RawVector())
  387. }
  388. }
  389. var sum float64
  390. for i := 0; i < la; i++ {
  391. sum += a.At(i, 0) * b.At(i, 0)
  392. }
  393. return sum
  394. }
  395. // Equal returns whether the matrices a and b have the same size
  396. // and are element-wise equal.
  397. func Equal(a, b Matrix) bool {
  398. ar, ac := a.Dims()
  399. br, bc := b.Dims()
  400. if ar != br || ac != bc {
  401. return false
  402. }
  403. aU, aTrans := untranspose(a)
  404. bU, bTrans := untranspose(b)
  405. if rma, ok := aU.(RawMatrixer); ok {
  406. if rmb, ok := bU.(RawMatrixer); ok {
  407. ra := rma.RawMatrix()
  408. rb := rmb.RawMatrix()
  409. if aTrans == bTrans {
  410. for i := 0; i < ra.Rows; i++ {
  411. for j := 0; j < ra.Cols; j++ {
  412. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  413. return false
  414. }
  415. }
  416. }
  417. return true
  418. }
  419. for i := 0; i < ra.Rows; i++ {
  420. for j := 0; j < ra.Cols; j++ {
  421. if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
  422. return false
  423. }
  424. }
  425. }
  426. return true
  427. }
  428. }
  429. if rma, ok := aU.(RawSymmetricer); ok {
  430. if rmb, ok := bU.(RawSymmetricer); ok {
  431. ra := rma.RawSymmetric()
  432. rb := rmb.RawSymmetric()
  433. // Symmetric matrices are always upper and equal to their transpose.
  434. for i := 0; i < ra.N; i++ {
  435. for j := i; j < ra.N; j++ {
  436. if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
  437. return false
  438. }
  439. }
  440. }
  441. return true
  442. }
  443. }
  444. if ra, ok := aU.(*VecDense); ok {
  445. if rb, ok := bU.(*VecDense); ok {
  446. // If the raw vectors are the same length they must either both be
  447. // transposed or both not transposed (or have length 1).
  448. for i := 0; i < ra.mat.N; i++ {
  449. if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
  450. return false
  451. }
  452. }
  453. return true
  454. }
  455. }
  456. for i := 0; i < ar; i++ {
  457. for j := 0; j < ac; j++ {
  458. if a.At(i, j) != b.At(i, j) {
  459. return false
  460. }
  461. }
  462. }
  463. return true
  464. }
  465. // EqualApprox returns whether the matrices a and b have the same size and contain all equal
  466. // elements with tolerance for element-wise equality specified by epsilon. Matrices
  467. // with non-equal shapes are not equal.
  468. func EqualApprox(a, b Matrix, epsilon float64) bool {
  469. ar, ac := a.Dims()
  470. br, bc := b.Dims()
  471. if ar != br || ac != bc {
  472. return false
  473. }
  474. aU, aTrans := untranspose(a)
  475. bU, bTrans := untranspose(b)
  476. if rma, ok := aU.(RawMatrixer); ok {
  477. if rmb, ok := bU.(RawMatrixer); ok {
  478. ra := rma.RawMatrix()
  479. rb := rmb.RawMatrix()
  480. if aTrans == bTrans {
  481. for i := 0; i < ra.Rows; i++ {
  482. for j := 0; j < ra.Cols; j++ {
  483. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  484. return false
  485. }
  486. }
  487. }
  488. return true
  489. }
  490. for i := 0; i < ra.Rows; i++ {
  491. for j := 0; j < ra.Cols; j++ {
  492. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
  493. return false
  494. }
  495. }
  496. }
  497. return true
  498. }
  499. }
  500. if rma, ok := aU.(RawSymmetricer); ok {
  501. if rmb, ok := bU.(RawSymmetricer); ok {
  502. ra := rma.RawSymmetric()
  503. rb := rmb.RawSymmetric()
  504. // Symmetric matrices are always upper and equal to their transpose.
  505. for i := 0; i < ra.N; i++ {
  506. for j := i; j < ra.N; j++ {
  507. if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
  508. return false
  509. }
  510. }
  511. }
  512. return true
  513. }
  514. }
  515. if ra, ok := aU.(*VecDense); ok {
  516. if rb, ok := bU.(*VecDense); ok {
  517. // If the raw vectors are the same length they must either both be
  518. // transposed or both not transposed (or have length 1).
  519. for i := 0; i < ra.mat.N; i++ {
  520. if !floats.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
  521. return false
  522. }
  523. }
  524. return true
  525. }
  526. }
  527. for i := 0; i < ar; i++ {
  528. for j := 0; j < ac; j++ {
  529. if !floats.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
  530. return false
  531. }
  532. }
  533. }
  534. return true
  535. }
  536. // LogDet returns the log of the determinant and the sign of the determinant
  537. // for the matrix that has been factorized. Numerical stability in product and
  538. // division expressions is generally improved by working in log space.
  539. func LogDet(a Matrix) (det float64, sign float64) {
  540. // TODO(btracey): Add specialized routines for TriDense, etc.
  541. var lu LU
  542. lu.Factorize(a)
  543. return lu.LogDet()
  544. }
  545. // Max returns the largest element value of the matrix A.
  546. // Max will panic with matrix.ErrShape if the matrix has zero size.
  547. func Max(a Matrix) float64 {
  548. r, c := a.Dims()
  549. if r == 0 || c == 0 {
  550. panic(ErrShape)
  551. }
  552. // Max(A) = Max(Aᵀ)
  553. aU, _ := untranspose(a)
  554. switch m := aU.(type) {
  555. case RawMatrixer:
  556. rm := m.RawMatrix()
  557. max := math.Inf(-1)
  558. for i := 0; i < rm.Rows; i++ {
  559. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  560. if v > max {
  561. max = v
  562. }
  563. }
  564. }
  565. return max
  566. case RawTriangular:
  567. rm := m.RawTriangular()
  568. // The max of a triangular is at least 0 unless the size is 1.
  569. if rm.N == 1 {
  570. return rm.Data[0]
  571. }
  572. max := 0.0
  573. if rm.Uplo == blas.Upper {
  574. for i := 0; i < rm.N; i++ {
  575. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  576. if v > max {
  577. max = v
  578. }
  579. }
  580. }
  581. return max
  582. }
  583. for i := 0; i < rm.N; i++ {
  584. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  585. if v > max {
  586. max = v
  587. }
  588. }
  589. }
  590. return max
  591. case RawSymmetricer:
  592. rm := m.RawSymmetric()
  593. if rm.Uplo != blas.Upper {
  594. panic(badSymTriangle)
  595. }
  596. max := math.Inf(-1)
  597. for i := 0; i < rm.N; i++ {
  598. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  599. if v > max {
  600. max = v
  601. }
  602. }
  603. }
  604. return max
  605. default:
  606. r, c := aU.Dims()
  607. max := math.Inf(-1)
  608. for i := 0; i < r; i++ {
  609. for j := 0; j < c; j++ {
  610. v := aU.At(i, j)
  611. if v > max {
  612. max = v
  613. }
  614. }
  615. }
  616. return max
  617. }
  618. }
  619. // Min returns the smallest element value of the matrix A.
  620. // Min will panic with matrix.ErrShape if the matrix has zero size.
  621. func Min(a Matrix) float64 {
  622. r, c := a.Dims()
  623. if r == 0 || c == 0 {
  624. panic(ErrShape)
  625. }
  626. // Min(A) = Min(Aᵀ)
  627. aU, _ := untranspose(a)
  628. switch m := aU.(type) {
  629. case RawMatrixer:
  630. rm := m.RawMatrix()
  631. min := math.Inf(1)
  632. for i := 0; i < rm.Rows; i++ {
  633. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  634. if v < min {
  635. min = v
  636. }
  637. }
  638. }
  639. return min
  640. case RawTriangular:
  641. rm := m.RawTriangular()
  642. // The min of a triangular is at most 0 unless the size is 1.
  643. if rm.N == 1 {
  644. return rm.Data[0]
  645. }
  646. min := 0.0
  647. if rm.Uplo == blas.Upper {
  648. for i := 0; i < rm.N; i++ {
  649. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  650. if v < min {
  651. min = v
  652. }
  653. }
  654. }
  655. return min
  656. }
  657. for i := 0; i < rm.N; i++ {
  658. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
  659. if v < min {
  660. min = v
  661. }
  662. }
  663. }
  664. return min
  665. case RawSymmetricer:
  666. rm := m.RawSymmetric()
  667. if rm.Uplo != blas.Upper {
  668. panic(badSymTriangle)
  669. }
  670. min := math.Inf(1)
  671. for i := 0; i < rm.N; i++ {
  672. for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
  673. if v < min {
  674. min = v
  675. }
  676. }
  677. }
  678. return min
  679. default:
  680. r, c := aU.Dims()
  681. min := math.Inf(1)
  682. for i := 0; i < r; i++ {
  683. for j := 0; j < c; j++ {
  684. v := aU.At(i, j)
  685. if v < min {
  686. min = v
  687. }
  688. }
  689. }
  690. return min
  691. }
  692. }
  693. // Norm returns the specified (induced) norm of the matrix a. See
  694. // https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm.
  695. //
  696. // Valid norms are:
  697. // 1 - The maximum absolute column sum
  698. // 2 - Frobenius norm, the square root of the sum of the squares of the elements.
  699. // Inf - The maximum absolute row sum.
  700. // Norm will panic with ErrNormOrder if an illegal norm order is specified and
  701. // with matrix.ErrShape if the matrix has zero size.
  702. func Norm(a Matrix, norm float64) float64 {
  703. r, c := a.Dims()
  704. if r == 0 || c == 0 {
  705. panic(ErrShape)
  706. }
  707. aU, aTrans := untranspose(a)
  708. var work []float64
  709. switch rma := aU.(type) {
  710. case RawMatrixer:
  711. rm := rma.RawMatrix()
  712. n := normLapack(norm, aTrans)
  713. if n == lapack.MaxColumnSum {
  714. work = getFloats(rm.Cols, false)
  715. defer putFloats(work)
  716. }
  717. return lapack64.Lange(n, rm, work)
  718. case RawTriangular:
  719. rm := rma.RawTriangular()
  720. n := normLapack(norm, aTrans)
  721. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  722. work = getFloats(rm.N, false)
  723. defer putFloats(work)
  724. }
  725. return lapack64.Lantr(n, rm, work)
  726. case RawSymmetricer:
  727. rm := rma.RawSymmetric()
  728. n := normLapack(norm, aTrans)
  729. if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
  730. work = getFloats(rm.N, false)
  731. defer putFloats(work)
  732. }
  733. return lapack64.Lansy(n, rm, work)
  734. case *VecDense:
  735. rv := rma.RawVector()
  736. switch norm {
  737. default:
  738. panic(ErrNormOrder)
  739. case 1:
  740. if aTrans {
  741. imax := blas64.Iamax(rv)
  742. return math.Abs(rma.At(imax, 0))
  743. }
  744. return blas64.Asum(rv)
  745. case 2:
  746. return blas64.Nrm2(rv)
  747. case math.Inf(1):
  748. if aTrans {
  749. return blas64.Asum(rv)
  750. }
  751. imax := blas64.Iamax(rv)
  752. return math.Abs(rma.At(imax, 0))
  753. }
  754. }
  755. switch norm {
  756. default:
  757. panic(ErrNormOrder)
  758. case 1:
  759. var max float64
  760. for j := 0; j < c; j++ {
  761. var sum float64
  762. for i := 0; i < r; i++ {
  763. sum += math.Abs(a.At(i, j))
  764. }
  765. if sum > max {
  766. max = sum
  767. }
  768. }
  769. return max
  770. case 2:
  771. var sum float64
  772. for i := 0; i < r; i++ {
  773. for j := 0; j < c; j++ {
  774. v := a.At(i, j)
  775. sum += v * v
  776. }
  777. }
  778. return math.Sqrt(sum)
  779. case math.Inf(1):
  780. var max float64
  781. for i := 0; i < r; i++ {
  782. var sum float64
  783. for j := 0; j < c; j++ {
  784. sum += math.Abs(a.At(i, j))
  785. }
  786. if sum > max {
  787. max = sum
  788. }
  789. }
  790. return max
  791. }
  792. }
  793. // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
  794. func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
  795. switch norm {
  796. case 1:
  797. n := lapack.MaxColumnSum
  798. if aTrans {
  799. n = lapack.MaxRowSum
  800. }
  801. return n
  802. case 2:
  803. return lapack.Frobenius
  804. case math.Inf(1):
  805. n := lapack.MaxRowSum
  806. if aTrans {
  807. n = lapack.MaxColumnSum
  808. }
  809. return n
  810. default:
  811. panic(ErrNormOrder)
  812. }
  813. }
  814. // Sum returns the sum of the elements of the matrix.
  815. func Sum(a Matrix) float64 {
  816. var sum float64
  817. aU, _ := untranspose(a)
  818. switch rma := aU.(type) {
  819. case RawSymmetricer:
  820. rm := rma.RawSymmetric()
  821. for i := 0; i < rm.N; i++ {
  822. // Diagonals count once while off-diagonals count twice.
  823. sum += rm.Data[i*rm.Stride+i]
  824. var s float64
  825. for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] {
  826. s += v
  827. }
  828. sum += 2 * s
  829. }
  830. return sum
  831. case RawTriangular:
  832. rm := rma.RawTriangular()
  833. var startIdx, endIdx int
  834. for i := 0; i < rm.N; i++ {
  835. // Start and end index for this triangle-row.
  836. switch rm.Uplo {
  837. case blas.Upper:
  838. startIdx = i
  839. endIdx = rm.N
  840. case blas.Lower:
  841. startIdx = 0
  842. endIdx = i + 1
  843. default:
  844. panic(badTriangle)
  845. }
  846. for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] {
  847. sum += v
  848. }
  849. }
  850. return sum
  851. case RawMatrixer:
  852. rm := rma.RawMatrix()
  853. for i := 0; i < rm.Rows; i++ {
  854. for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
  855. sum += v
  856. }
  857. }
  858. return sum
  859. case *VecDense:
  860. rm := rma.RawVector()
  861. for i := 0; i < rm.N; i++ {
  862. sum += rm.Data[i*rm.Inc]
  863. }
  864. return sum
  865. default:
  866. r, c := a.Dims()
  867. for i := 0; i < r; i++ {
  868. for j := 0; j < c; j++ {
  869. sum += a.At(i, j)
  870. }
  871. }
  872. return sum
  873. }
  874. }
  // A Tracer can compute the trace (the sum of the diagonal elements) of the
// matrix. Implementations must panic when the matrix is not square.
type Tracer interface {
	Trace() float64
}
  880. // Trace returns the trace of the matrix. Trace will panic if the
  881. // matrix is not square.
  882. func Trace(a Matrix) float64 {
  883. m, _ := untransposeExtract(a)
  884. if t, ok := m.(Tracer); ok {
  885. return t.Trace()
  886. }
  887. r, c := a.Dims()
  888. if r != c {
  889. panic(ErrSquare)
  890. }
  891. var v float64
  892. for i := 0; i < r; i++ {
  893. v += a.At(i, i)
  894. }
  895. return v
  896. }
  // min returns the smaller of the two ints a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
  // max returns the larger of the two ints a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
  // use returns a float64 slice of length l. It reslices f when f's capacity
// is large enough and otherwise allocates a fresh slice. Existing contents
// are not cleared.
func use(f []float64, l int) []float64 {
	if cap(f) < l {
		return make([]float64, l)
	}
	return f[:l]
}
  // useZeroed returns a float64 slice of length l whose elements are all zero.
// It reslices and clears f when f's capacity is large enough and otherwise
// allocates a fresh (already zeroed) slice.
func useZeroed(f []float64, l int) []float64 {
	if cap(f) < l {
		return make([]float64, l)
	}
	f = f[:l]
	for i := range f {
		f[i] = 0
	}
	return f
}
  // zero sets every element of f to 0 in place.
func zero(f []float64) {
	for i := 0; i < len(f); i++ {
		f[i] = 0
	}
}
  // useInt returns an int slice of length l. It reslices i when i's capacity
// is large enough and otherwise allocates a fresh slice. Existing contents
// are not cleared.
func useInt(i []int, l int) []int {
	if cap(i) < l {
		return make([]int, l)
	}
	return i[:l]
}