symmetric.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. )
  10. var (
  11. symDense *SymDense
  12. _ Matrix = symDense
  13. _ Symmetric = symDense
  14. _ RawSymmetricer = symDense
  15. _ MutableSymmetric = symDense
  16. )
  // Panic messages raised by SymDense methods.
const (
	// badSymTriangle is raised when a blas64.Symmetric does not use
	// upper triangular storage.
	badSymTriangle = "mat: blas64.Symmetric not upper"

	// badSymCap is raised when a SymDense's dimension exceeds its capacity.
	badSymCap = "mat: bad capacity for SymDense"
)
  21. // SymDense is a symmetric matrix that uses dense storage. SymDense
  22. // matrices are stored in the upper triangle.
  23. type SymDense struct {
  24. mat blas64.Symmetric
  25. cap int
  26. }
  27. // Symmetric represents a symmetric matrix (where the element at {i, j} equals
  28. // the element at {j, i}). Symmetric matrices are always square.
  29. type Symmetric interface {
  30. Matrix
  31. // Symmetric returns the number of rows/columns in the matrix.
  32. Symmetric() int
  33. }
  34. // A RawSymmetricer can return a view of itself as a BLAS Symmetric matrix.
  35. type RawSymmetricer interface {
  36. RawSymmetric() blas64.Symmetric
  37. }
  38. // A MutableSymmetric can set elements of a symmetric matrix.
  39. type MutableSymmetric interface {
  40. Symmetric
  41. SetSym(i, j int, v float64)
  42. }
  43. // NewSymDense creates a new Symmetric matrix with n rows and columns. If data == nil,
  44. // a new slice is allocated for the backing slice. If len(data) == n*n, data is
  45. // used as the backing slice, and changes to the elements of the returned SymDense
  46. // will be reflected in data. If neither of these is true, NewSymDense will panic.
  47. // NewSymDense will panic if n is zero.
  48. //
  49. // The data must be arranged in row-major order, i.e. the (i*c + j)-th
  50. // element in the data slice is the {i, j}-th element in the matrix.
  51. // Only the values in the upper triangular portion of the matrix are used.
  52. func NewSymDense(n int, data []float64) *SymDense {
  53. if n <= 0 {
  54. if n == 0 {
  55. panic(ErrZeroLength)
  56. }
  57. panic("mat: negative dimension")
  58. }
  59. if data != nil && n*n != len(data) {
  60. panic(ErrShape)
  61. }
  62. if data == nil {
  63. data = make([]float64, n*n)
  64. }
  65. return &SymDense{
  66. mat: blas64.Symmetric{
  67. N: n,
  68. Stride: n,
  69. Data: data,
  70. Uplo: blas.Upper,
  71. },
  72. cap: n,
  73. }
  74. }
  75. // Dims returns the number of rows and columns in the matrix.
  76. func (s *SymDense) Dims() (r, c int) {
  77. return s.mat.N, s.mat.N
  78. }
  79. // Caps returns the number of rows and columns in the backing matrix.
  80. func (s *SymDense) Caps() (r, c int) {
  81. return s.cap, s.cap
  82. }
  83. // T returns the receiver, the transpose of a symmetric matrix.
  84. func (s *SymDense) T() Matrix {
  85. return s
  86. }
  87. // Symmetric implements the Symmetric interface and returns the number of rows
  88. // and columns in the matrix.
  89. func (s *SymDense) Symmetric() int {
  90. return s.mat.N
  91. }
  92. // RawSymmetric returns the matrix as a blas64.Symmetric. The returned
  93. // value must be stored in upper triangular format.
  94. func (s *SymDense) RawSymmetric() blas64.Symmetric {
  95. return s.mat
  96. }
  97. // SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
  98. // Changes to elements in the receiver following the call will be reflected
  99. // in the input.
  100. //
  101. // The supplied Symmetric must use blas.Upper storage format.
  102. func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) {
  103. if mat.Uplo != blas.Upper {
  104. panic(badSymTriangle)
  105. }
  106. s.mat = mat
  107. }
  108. // Reset zeros the dimensions of the matrix so that it can be reused as the
  109. // receiver of a dimensionally restricted operation.
  110. //
  111. // See the Reseter interface for more information.
  112. func (s *SymDense) Reset() {
  113. // N and Stride must be zeroed in unison.
  114. s.mat.N, s.mat.Stride = 0, 0
  115. s.mat.Data = s.mat.Data[:0]
  116. }
  117. // Zero sets all of the matrix elements to zero.
  118. func (s *SymDense) Zero() {
  119. for i := 0; i < s.mat.N; i++ {
  120. zero(s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+s.mat.N])
  121. }
  122. }
  123. // IsZero returns whether the receiver is zero-sized. Zero-sized matrices can be the
  124. // receiver for size-restricted operations. SymDense matrices can be zeroed using Reset.
  125. func (s *SymDense) IsZero() bool {
  126. // It must be the case that m.Dims() returns
  127. // zeros in this case. See comment in Reset().
  128. return s.mat.N == 0
  129. }
  130. // reuseAs resizes an empty matrix to a n×n matrix,
  131. // or checks that a non-empty matrix is n×n.
  132. func (s *SymDense) reuseAs(n int) {
  133. if n == 0 {
  134. panic(ErrZeroLength)
  135. }
  136. if s.mat.N > s.cap {
  137. panic(badSymCap)
  138. }
  139. if s.IsZero() {
  140. s.mat = blas64.Symmetric{
  141. N: n,
  142. Stride: n,
  143. Data: use(s.mat.Data, n*n),
  144. Uplo: blas.Upper,
  145. }
  146. s.cap = n
  147. return
  148. }
  149. if s.mat.Uplo != blas.Upper {
  150. panic(badSymTriangle)
  151. }
  152. if s.mat.N != n {
  153. panic(ErrShape)
  154. }
  155. }
  156. func (s *SymDense) isolatedWorkspace(a Symmetric) (w *SymDense, restore func()) {
  157. n := a.Symmetric()
  158. if n == 0 {
  159. panic(ErrZeroLength)
  160. }
  161. w = getWorkspaceSym(n, false)
  162. return w, func() {
  163. s.CopySym(w)
  164. putWorkspaceSym(w)
  165. }
  166. }
  167. // DiagView returns the diagonal as a matrix backed by the original data.
  168. func (s *SymDense) DiagView() Diagonal {
  169. n := s.mat.N
  170. return &DiagDense{
  171. mat: blas64.Vector{
  172. N: n,
  173. Inc: s.mat.Stride + 1,
  174. Data: s.mat.Data[:(n-1)*s.mat.Stride+n],
  175. },
  176. }
  177. }
  178. func (s *SymDense) AddSym(a, b Symmetric) {
  179. n := a.Symmetric()
  180. if n != b.Symmetric() {
  181. panic(ErrShape)
  182. }
  183. s.reuseAs(n)
  184. if a, ok := a.(RawSymmetricer); ok {
  185. if b, ok := b.(RawSymmetricer); ok {
  186. amat, bmat := a.RawSymmetric(), b.RawSymmetric()
  187. if s != a {
  188. s.checkOverlap(generalFromSymmetric(amat))
  189. }
  190. if s != b {
  191. s.checkOverlap(generalFromSymmetric(bmat))
  192. }
  193. for i := 0; i < n; i++ {
  194. btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
  195. stmp := s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+n]
  196. for j, v := range amat.Data[i*amat.Stride+i : i*amat.Stride+n] {
  197. stmp[j] = v + btmp[j]
  198. }
  199. }
  200. return
  201. }
  202. }
  203. s.checkOverlapMatrix(a)
  204. s.checkOverlapMatrix(b)
  205. for i := 0; i < n; i++ {
  206. stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
  207. for j := i; j < n; j++ {
  208. stmp[j] = a.At(i, j) + b.At(i, j)
  209. }
  210. }
  211. }
  212. func (s *SymDense) CopySym(a Symmetric) int {
  213. n := a.Symmetric()
  214. n = min(n, s.mat.N)
  215. if n == 0 {
  216. return 0
  217. }
  218. switch a := a.(type) {
  219. case RawSymmetricer:
  220. amat := a.RawSymmetric()
  221. if amat.Uplo != blas.Upper {
  222. panic(badSymTriangle)
  223. }
  224. for i := 0; i < n; i++ {
  225. copy(s.mat.Data[i*s.mat.Stride+i:i*s.mat.Stride+n], amat.Data[i*amat.Stride+i:i*amat.Stride+n])
  226. }
  227. default:
  228. for i := 0; i < n; i++ {
  229. stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
  230. for j := i; j < n; j++ {
  231. stmp[j] = a.At(i, j)
  232. }
  233. }
  234. }
  235. return n
  236. }
  237. // SymRankOne performs a symetric rank-one update to the matrix a and stores
  238. // the result in the receiver
  239. // s = a + alpha * x * x'
  240. func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x Vector) {
  241. n, c := x.Dims()
  242. if a.Symmetric() != n || c != 1 {
  243. panic(ErrShape)
  244. }
  245. s.reuseAs(n)
  246. if s != a {
  247. if rs, ok := a.(RawSymmetricer); ok {
  248. s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
  249. }
  250. s.CopySym(a)
  251. }
  252. xU, _ := untranspose(x)
  253. if rv, ok := xU.(RawVectorer); ok {
  254. xmat := rv.RawVector()
  255. s.checkOverlap((&VecDense{mat: xmat}).asGeneral())
  256. blas64.Syr(alpha, xmat, s.mat)
  257. return
  258. }
  259. for i := 0; i < n; i++ {
  260. for j := i; j < n; j++ {
  261. s.set(i, j, s.at(i, j)+alpha*x.AtVec(i)*x.AtVec(j))
  262. }
  263. }
  264. }
  265. // SymRankK performs a symmetric rank-k update to the matrix a and stores the
  266. // result into the receiver. If a is zero, see SymOuterK.
  267. // s = a + alpha * x * x'
  268. func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) {
  269. n := a.Symmetric()
  270. r, _ := x.Dims()
  271. if r != n {
  272. panic(ErrShape)
  273. }
  274. xMat, aTrans := untranspose(x)
  275. var g blas64.General
  276. if rm, ok := xMat.(RawMatrixer); ok {
  277. g = rm.RawMatrix()
  278. } else {
  279. g = DenseCopyOf(x).mat
  280. aTrans = false
  281. }
  282. if a != s {
  283. if rs, ok := a.(RawSymmetricer); ok {
  284. s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
  285. }
  286. s.reuseAs(n)
  287. s.CopySym(a)
  288. }
  289. t := blas.NoTrans
  290. if aTrans {
  291. t = blas.Trans
  292. }
  293. blas64.Syrk(t, alpha, g, 1, s.mat)
  294. }
  295. // SymOuterK calculates the outer product of x with itself and stores
  296. // the result into the receiver. It is equivalent to the matrix
  297. // multiplication
  298. // s = alpha * x * x'.
  299. // In order to update an existing matrix, see SymRankOne.
  300. func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
  301. n, _ := x.Dims()
  302. switch {
  303. case s.IsZero():
  304. s.mat = blas64.Symmetric{
  305. N: n,
  306. Stride: n,
  307. Data: useZeroed(s.mat.Data, n*n),
  308. Uplo: blas.Upper,
  309. }
  310. s.cap = n
  311. s.SymRankK(s, alpha, x)
  312. case s.mat.Uplo != blas.Upper:
  313. panic(badSymTriangle)
  314. case s.mat.N == n:
  315. if s == x {
  316. w := getWorkspaceSym(n, true)
  317. w.SymRankK(w, alpha, x)
  318. s.CopySym(w)
  319. putWorkspaceSym(w)
  320. } else {
  321. switch r := x.(type) {
  322. case RawMatrixer:
  323. s.checkOverlap(r.RawMatrix())
  324. case RawSymmetricer:
  325. s.checkOverlap(generalFromSymmetric(r.RawSymmetric()))
  326. case RawTriangular:
  327. s.checkOverlap(generalFromTriangular(r.RawTriangular()))
  328. }
  329. // Only zero the upper triangle.
  330. for i := 0; i < n; i++ {
  331. ri := i * s.mat.Stride
  332. zero(s.mat.Data[ri+i : ri+n])
  333. }
  334. s.SymRankK(s, alpha, x)
  335. }
  336. default:
  337. panic(ErrShape)
  338. }
  339. }
  340. // RankTwo performs a symmmetric rank-two update to the matrix a and stores
  341. // the result in the receiver
  342. // m = a + alpha * (x * y' + y * x')
  343. func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y Vector) {
  344. n := s.mat.N
  345. xr, xc := x.Dims()
  346. if xr != n || xc != 1 {
  347. panic(ErrShape)
  348. }
  349. yr, yc := y.Dims()
  350. if yr != n || yc != 1 {
  351. panic(ErrShape)
  352. }
  353. if s != a {
  354. if rs, ok := a.(RawSymmetricer); ok {
  355. s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
  356. }
  357. }
  358. var xmat, ymat blas64.Vector
  359. fast := true
  360. xU, _ := untranspose(x)
  361. if rv, ok := xU.(RawVectorer); ok {
  362. xmat = rv.RawVector()
  363. s.checkOverlap((&VecDense{mat: xmat}).asGeneral())
  364. } else {
  365. fast = false
  366. }
  367. yU, _ := untranspose(y)
  368. if rv, ok := yU.(RawVectorer); ok {
  369. ymat = rv.RawVector()
  370. s.checkOverlap((&VecDense{mat: ymat}).asGeneral())
  371. } else {
  372. fast = false
  373. }
  374. if s != a {
  375. if rs, ok := a.(RawSymmetricer); ok {
  376. s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
  377. }
  378. s.reuseAs(n)
  379. s.CopySym(a)
  380. }
  381. if fast {
  382. if s != a {
  383. s.reuseAs(n)
  384. s.CopySym(a)
  385. }
  386. blas64.Syr2(alpha, xmat, ymat, s.mat)
  387. return
  388. }
  389. for i := 0; i < n; i++ {
  390. s.reuseAs(n)
  391. for j := i; j < n; j++ {
  392. s.set(i, j, a.At(i, j)+alpha*(x.AtVec(i)*y.AtVec(j)+y.AtVec(i)*x.AtVec(j)))
  393. }
  394. }
  395. }
  396. // ScaleSym multiplies the elements of a by f, placing the result in the receiver.
  397. func (s *SymDense) ScaleSym(f float64, a Symmetric) {
  398. n := a.Symmetric()
  399. s.reuseAs(n)
  400. if a, ok := a.(RawSymmetricer); ok {
  401. amat := a.RawSymmetric()
  402. if s != a {
  403. s.checkOverlap(generalFromSymmetric(amat))
  404. }
  405. for i := 0; i < n; i++ {
  406. for j := i; j < n; j++ {
  407. s.mat.Data[i*s.mat.Stride+j] = f * amat.Data[i*amat.Stride+j]
  408. }
  409. }
  410. return
  411. }
  412. for i := 0; i < n; i++ {
  413. for j := i; j < n; j++ {
  414. s.mat.Data[i*s.mat.Stride+j] = f * a.At(i, j)
  415. }
  416. }
  417. }
  418. // SubsetSym extracts a subset of the rows and columns of the matrix a and stores
  419. // the result in-place into the receiver. The resulting matrix size is
  420. // len(set)×len(set). Specifically, at the conclusion of SubsetSym,
  421. // s.At(i, j) equals a.At(set[i], set[j]). Note that the supplied set does not
  422. // have to be a strict subset, dimension repeats are allowed.
  423. func (s *SymDense) SubsetSym(a Symmetric, set []int) {
  424. n := len(set)
  425. na := a.Symmetric()
  426. s.reuseAs(n)
  427. var restore func()
  428. if a == s {
  429. s, restore = s.isolatedWorkspace(a)
  430. defer restore()
  431. }
  432. if a, ok := a.(RawSymmetricer); ok {
  433. raw := a.RawSymmetric()
  434. if s != a {
  435. s.checkOverlap(generalFromSymmetric(raw))
  436. }
  437. for i := 0; i < n; i++ {
  438. ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
  439. r := set[i]
  440. rsub := raw.Data[r*raw.Stride : r*raw.Stride+na]
  441. for j := i; j < n; j++ {
  442. c := set[j]
  443. if r <= c {
  444. ssub[j] = rsub[c]
  445. } else {
  446. ssub[j] = raw.Data[c*raw.Stride+r]
  447. }
  448. }
  449. }
  450. return
  451. }
  452. for i := 0; i < n; i++ {
  453. for j := i; j < n; j++ {
  454. s.mat.Data[i*s.mat.Stride+j] = a.At(set[i], set[j])
  455. }
  456. }
  457. }
  458. // SliceSym returns a new Matrix that shares backing data with the receiver.
  459. // The returned matrix starts at {i,i} of the receiver and extends k-i rows
  460. // and columns. The final row and column in the resulting matrix is k-1.
  461. // SliceSym panics with ErrIndexOutOfRange if the slice is outside the
  462. // capacity of the receiver.
  463. func (s *SymDense) SliceSym(i, k int) Symmetric {
  464. sz := s.cap
  465. if i < 0 || sz < i || k < i || sz < k {
  466. panic(ErrIndexOutOfRange)
  467. }
  468. v := *s
  469. v.mat.Data = s.mat.Data[i*s.mat.Stride+i : (k-1)*s.mat.Stride+k]
  470. v.mat.N = k - i
  471. v.cap = s.cap - i
  472. return &v
  473. }
  474. // Trace returns the trace of the matrix.
  475. func (s *SymDense) Trace() float64 {
  476. // TODO(btracey): could use internal asm sum routine.
  477. var v float64
  478. for i := 0; i < s.mat.N; i++ {
  479. v += s.mat.Data[i*s.mat.Stride+i]
  480. }
  481. return v
  482. }
  483. // GrowSym returns the receiver expanded by n rows and n columns. If the
  484. // dimensions of the expanded matrix are outside the capacity of the receiver
  485. // a new allocation is made, otherwise not. Note that the receiver itself is
  486. // not modified during the call to GrowSquare.
  487. func (s *SymDense) GrowSym(n int) Symmetric {
  488. if n < 0 {
  489. panic(ErrIndexOutOfRange)
  490. }
  491. if n == 0 {
  492. return s
  493. }
  494. var v SymDense
  495. n += s.mat.N
  496. if n > s.cap {
  497. v.mat = blas64.Symmetric{
  498. N: n,
  499. Stride: n,
  500. Uplo: blas.Upper,
  501. Data: make([]float64, n*n),
  502. }
  503. v.cap = n
  504. // Copy elements, including those not currently visible. Use a temporary
  505. // structure to avoid modifying the receiver.
  506. var tmp SymDense
  507. tmp.mat = blas64.Symmetric{
  508. N: s.cap,
  509. Stride: s.mat.Stride,
  510. Data: s.mat.Data,
  511. Uplo: s.mat.Uplo,
  512. }
  513. tmp.cap = s.cap
  514. v.CopySym(&tmp)
  515. return &v
  516. }
  517. v.mat = blas64.Symmetric{
  518. N: n,
  519. Stride: s.mat.Stride,
  520. Uplo: blas.Upper,
  521. Data: s.mat.Data[:(n-1)*s.mat.Stride+n],
  522. }
  523. v.cap = s.cap
  524. return &v
  525. }
  526. // PowPSD computes a^pow where a is a positive symmetric definite matrix.
  527. //
  528. // PowPSD returns an error if the matrix is not not positive symmetric definite
  529. // or the Eigendecomposition is not successful.
  530. func (s *SymDense) PowPSD(a Symmetric, pow float64) error {
  531. dim := a.Symmetric()
  532. s.reuseAs(dim)
  533. var eigen EigenSym
  534. ok := eigen.Factorize(a, true)
  535. if !ok {
  536. return ErrFailedEigen
  537. }
  538. values := eigen.Values(nil)
  539. for i, v := range values {
  540. if v <= 0 {
  541. return ErrNotPSD
  542. }
  543. values[i] = math.Pow(v, pow)
  544. }
  545. u := eigen.VectorsTo(nil)
  546. s.SymOuterK(values[0], u.ColView(0))
  547. var v VecDense
  548. for i := 1; i < dim; i++ {
  549. v.ColViewOf(u, i)
  550. s.SymRankOne(s, values[i], &v)
  551. }
  552. return nil
  553. }