dense_arithmetic.go 20 KB


  1. // Copyright ©2013 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "math"
  7. "gonum.org/v1/gonum/blas"
  8. "gonum.org/v1/gonum/blas/blas64"
  9. "gonum.org/v1/gonum/lapack/lapack64"
  10. )
  11. // Add adds a and b element-wise, placing the result in the receiver. Add
  12. // will panic if the two matrices do not have the same shape.
  13. func (m *Dense) Add(a, b Matrix) {
  14. ar, ac := a.Dims()
  15. br, bc := b.Dims()
  16. if ar != br || ac != bc {
  17. panic(ErrShape)
  18. }
  19. aU, _ := untranspose(a)
  20. bU, _ := untranspose(b)
  21. m.reuseAs(ar, ac)
  22. if arm, ok := a.(RawMatrixer); ok {
  23. if brm, ok := b.(RawMatrixer); ok {
  24. amat, bmat := arm.RawMatrix(), brm.RawMatrix()
  25. if m != aU {
  26. m.checkOverlap(amat)
  27. }
  28. if m != bU {
  29. m.checkOverlap(bmat)
  30. }
  31. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  32. for i, v := range amat.Data[ja : ja+ac] {
  33. m.mat.Data[i+jm] = v + bmat.Data[i+jb]
  34. }
  35. }
  36. return
  37. }
  38. }
  39. m.checkOverlapMatrix(aU)
  40. m.checkOverlapMatrix(bU)
  41. var restore func()
  42. if m == aU {
  43. m, restore = m.isolatedWorkspace(aU)
  44. defer restore()
  45. } else if m == bU {
  46. m, restore = m.isolatedWorkspace(bU)
  47. defer restore()
  48. }
  49. for r := 0; r < ar; r++ {
  50. for c := 0; c < ac; c++ {
  51. m.set(r, c, a.At(r, c)+b.At(r, c))
  52. }
  53. }
  54. }
  55. // Sub subtracts the matrix b from a, placing the result in the receiver. Sub
  56. // will panic if the two matrices do not have the same shape.
  57. func (m *Dense) Sub(a, b Matrix) {
  58. ar, ac := a.Dims()
  59. br, bc := b.Dims()
  60. if ar != br || ac != bc {
  61. panic(ErrShape)
  62. }
  63. aU, _ := untranspose(a)
  64. bU, _ := untranspose(b)
  65. m.reuseAs(ar, ac)
  66. if arm, ok := a.(RawMatrixer); ok {
  67. if brm, ok := b.(RawMatrixer); ok {
  68. amat, bmat := arm.RawMatrix(), brm.RawMatrix()
  69. if m != aU {
  70. m.checkOverlap(amat)
  71. }
  72. if m != bU {
  73. m.checkOverlap(bmat)
  74. }
  75. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  76. for i, v := range amat.Data[ja : ja+ac] {
  77. m.mat.Data[i+jm] = v - bmat.Data[i+jb]
  78. }
  79. }
  80. return
  81. }
  82. }
  83. m.checkOverlapMatrix(aU)
  84. m.checkOverlapMatrix(bU)
  85. var restore func()
  86. if m == aU {
  87. m, restore = m.isolatedWorkspace(aU)
  88. defer restore()
  89. } else if m == bU {
  90. m, restore = m.isolatedWorkspace(bU)
  91. defer restore()
  92. }
  93. for r := 0; r < ar; r++ {
  94. for c := 0; c < ac; c++ {
  95. m.set(r, c, a.At(r, c)-b.At(r, c))
  96. }
  97. }
  98. }
  99. // MulElem performs element-wise multiplication of a and b, placing the result
  100. // in the receiver. MulElem will panic if the two matrices do not have the same
  101. // shape.
  102. func (m *Dense) MulElem(a, b Matrix) {
  103. ar, ac := a.Dims()
  104. br, bc := b.Dims()
  105. if ar != br || ac != bc {
  106. panic(ErrShape)
  107. }
  108. aU, _ := untranspose(a)
  109. bU, _ := untranspose(b)
  110. m.reuseAs(ar, ac)
  111. if arm, ok := a.(RawMatrixer); ok {
  112. if brm, ok := b.(RawMatrixer); ok {
  113. amat, bmat := arm.RawMatrix(), brm.RawMatrix()
  114. if m != aU {
  115. m.checkOverlap(amat)
  116. }
  117. if m != bU {
  118. m.checkOverlap(bmat)
  119. }
  120. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  121. for i, v := range amat.Data[ja : ja+ac] {
  122. m.mat.Data[i+jm] = v * bmat.Data[i+jb]
  123. }
  124. }
  125. return
  126. }
  127. }
  128. m.checkOverlapMatrix(aU)
  129. m.checkOverlapMatrix(bU)
  130. var restore func()
  131. if m == aU {
  132. m, restore = m.isolatedWorkspace(aU)
  133. defer restore()
  134. } else if m == bU {
  135. m, restore = m.isolatedWorkspace(bU)
  136. defer restore()
  137. }
  138. for r := 0; r < ar; r++ {
  139. for c := 0; c < ac; c++ {
  140. m.set(r, c, a.At(r, c)*b.At(r, c))
  141. }
  142. }
  143. }
  144. // DivElem performs element-wise division of a by b, placing the result
  145. // in the receiver. DivElem will panic if the two matrices do not have the same
  146. // shape.
  147. func (m *Dense) DivElem(a, b Matrix) {
  148. ar, ac := a.Dims()
  149. br, bc := b.Dims()
  150. if ar != br || ac != bc {
  151. panic(ErrShape)
  152. }
  153. aU, _ := untranspose(a)
  154. bU, _ := untranspose(b)
  155. m.reuseAs(ar, ac)
  156. if arm, ok := a.(RawMatrixer); ok {
  157. if brm, ok := b.(RawMatrixer); ok {
  158. amat, bmat := arm.RawMatrix(), brm.RawMatrix()
  159. if m != aU {
  160. m.checkOverlap(amat)
  161. }
  162. if m != bU {
  163. m.checkOverlap(bmat)
  164. }
  165. for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
  166. for i, v := range amat.Data[ja : ja+ac] {
  167. m.mat.Data[i+jm] = v / bmat.Data[i+jb]
  168. }
  169. }
  170. return
  171. }
  172. }
  173. m.checkOverlapMatrix(aU)
  174. m.checkOverlapMatrix(bU)
  175. var restore func()
  176. if m == aU {
  177. m, restore = m.isolatedWorkspace(aU)
  178. defer restore()
  179. } else if m == bU {
  180. m, restore = m.isolatedWorkspace(bU)
  181. defer restore()
  182. }
  183. for r := 0; r < ar; r++ {
  184. for c := 0; c < ac; c++ {
  185. m.set(r, c, a.At(r, c)/b.At(r, c))
  186. }
  187. }
  188. }
  189. // Inverse computes the inverse of the matrix a, storing the result into the
  190. // receiver. If a is ill-conditioned, a Condition error will be returned.
  191. // Note that matrix inversion is numerically unstable, and should generally
  192. // be avoided where possible, for example by using the Solve routines.
  193. func (m *Dense) Inverse(a Matrix) error {
  194. // TODO(btracey): Special case for RawTriangular, etc.
  195. r, c := a.Dims()
  196. if r != c {
  197. panic(ErrSquare)
  198. }
  199. m.reuseAs(a.Dims())
  200. aU, aTrans := untranspose(a)
  201. switch rm := aU.(type) {
  202. case RawMatrixer:
  203. if m != aU || aTrans {
  204. if m == aU || m.checkOverlap(rm.RawMatrix()) {
  205. tmp := getWorkspace(r, c, false)
  206. tmp.Copy(a)
  207. m.Copy(tmp)
  208. putWorkspace(tmp)
  209. break
  210. }
  211. m.Copy(a)
  212. }
  213. default:
  214. m.Copy(a)
  215. }
  216. ipiv := getInts(r, false)
  217. defer putInts(ipiv)
  218. ok := lapack64.Getrf(m.mat, ipiv)
  219. if !ok {
  220. return Condition(math.Inf(1))
  221. }
  222. work := getFloats(4*r, false) // must be at least 4*r for cond.
  223. lapack64.Getri(m.mat, ipiv, work, -1)
  224. if int(work[0]) > 4*r {
  225. l := int(work[0])
  226. putFloats(work)
  227. work = getFloats(l, false)
  228. } else {
  229. work = work[:4*r]
  230. }
  231. defer putFloats(work)
  232. lapack64.Getri(m.mat, ipiv, work, len(work))
  233. norm := lapack64.Lange(CondNorm, m.mat, work)
  234. rcond := lapack64.Gecon(CondNorm, m.mat, norm, work, ipiv) // reuse ipiv
  235. if rcond == 0 {
  236. return Condition(math.Inf(1))
  237. }
  238. cond := 1 / rcond
  239. if cond > ConditionTolerance {
  240. return Condition(cond)
  241. }
  242. return nil
  243. }
// Mul takes the matrix product of a and b, placing the result in the receiver.
// If the number of columns in a does not equal the number of rows in b, Mul will panic.
//
// Depending on the concrete types behind the untransposed operands, the
// product is computed with BLAS Gemm, Symm, Trmm or Gemv; otherwise Mul
// falls back to a naive row-by-column loop over At. The receiver may be one
// of the operands, in which case the product is formed in a temporary
// obtained from isolatedWorkspace.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ac != br {
		panic(ErrShape)
	}
	aU, aTrans := untranspose(a)
	bU, bTrans := untranspose(b)
	m.reuseAs(ar, bc)
	// If m is an operand, compute into a workspace so the operand is not
	// overwritten while it is still being read; the deferred restore is
	// expected to move the product back into m. A non-nil restore is also
	// used below to skip overlap checks against the operands.
	var restore func()
	if m == aU {
		m, restore = m.isolatedWorkspace(aU)
		defer restore()
	} else if m == bU {
		m, restore = m.isolatedWorkspace(bU)
		defer restore()
	}
	// Translate the transpose flags from untranspose into BLAS options.
	aT := blas.NoTrans
	if aTrans {
		aT = blas.Trans
	}
	bT := blas.NoTrans
	if bTrans {
		bT = blas.Trans
	}
	// Some of the cases do not have a transpose option, so create
	// temporary memory.
	// C = A^T * B = (B^T * A)^T
	// C^T = B^T * A.
	if aUrm, ok := aU.(RawMatrixer); ok {
		amat := aUrm.RawMatrix()
		if restore == nil {
			m.checkOverlap(amat)
		}
		// General × general: a single Gemm handles both transpose flags.
		if bUrm, ok := bU.(RawMatrixer); ok {
			bmat := bUrm.RawMatrix()
			if restore == nil {
				m.checkOverlap(bmat)
			}
			blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
			return
		}
		// General × symmetric. Symm has no transpose option, so when a is
		// transposed compute C^T = B * aU (B symmetric) into a workspace
		// and copy its transpose into m.
		if bU, ok := bU.(RawSymmetricer); ok {
			bmat := bU.RawSymmetric()
			if aTrans {
				c := getWorkspace(ac, ar, false)
				blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
			return
		}
		// General × triangular.
		if bU, ok := bU.(RawTriangular); ok {
			// Trmm updates in place, so copy aU first.
			bmat := bU.RawTriangular()
			if aTrans {
				// Compute C^T = op(B)^T * aU in a workspace: copy aU,
				// multiply from the left with the transpose flag of b
				// inverted, then copy the transpose into m.
				c := getWorkspace(ac, ar, false)
				var tmp Dense
				tmp.SetRawMatrix(amat)
				c.Copy(&tmp)
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			m.Copy(a)
			blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
			return
		}
		// General × vector.
		if bU, ok := bU.(*VecDense); ok {
			m.checkOverlap(bU.asGeneral())
			bvec := bU.RawVector()
			if bTrans {
				// {ar,1} x {1,bc}, which is not a vector.
				// Instead, construct B as a General.
				bmat := blas64.General{
					Rows:   bc,
					Cols:   1,
					Stride: bvec.Inc,
					Data:   bvec.Data,
				}
				blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
				return
			}
			// The result is a single column of m, addressed as a vector
			// whose increment is m's row stride.
			cvec := blas64.Vector{
				Inc:  m.mat.Stride,
				Data: m.mat.Data,
			}
			blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
			return
		}
	}
	// From here on aU is not a RawMatrixer (or bU matched none of the
	// cases above); try the mirrored cases with b as the general operand.
	if bUrm, ok := bU.(RawMatrixer); ok {
		bmat := bUrm.RawMatrix()
		if restore == nil {
			m.checkOverlap(bmat)
		}
		// Symmetric × general. When b is transposed compute
		// C^T = bU * A (A symmetric) in a workspace and copy its transpose.
		if aU, ok := aU.(RawSymmetricer); ok {
			amat := aU.RawSymmetric()
			if bTrans {
				c := getWorkspace(bc, br, false)
				blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
			return
		}
		// Triangular × general.
		if aU, ok := aU.(RawTriangular); ok {
			// Trmm updates in place, so copy bU first.
			amat := aU.RawTriangular()
			if bTrans {
				// Compute C^T = bU * op(A)^T in a workspace with the
				// transpose flag of a inverted, then copy the transpose.
				c := getWorkspace(bc, br, false)
				var tmp Dense
				tmp.SetRawMatrix(bmat)
				c.Copy(&tmp)
				aT := blas.Trans
				if aTrans {
					aT = blas.NoTrans
				}
				blas64.Trmm(blas.Right, aT, 1, amat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			m.Copy(b)
			blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
			return
		}
		// Vector × general.
		if aU, ok := aU.(*VecDense); ok {
			m.checkOverlap(aU.asGeneral())
			avec := aU.RawVector()
			if aTrans {
				// {1,ac} x {ac, bc}
				// Transpose B so that the vector is on the right.
				cvec := blas64.Vector{
					Inc:  1,
					Data: m.mat.Data,
				}
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Gemv(bT, 1, bmat, avec, 0, cvec)
				return
			}
			// {ar,1} x {1,bc} which is not a vector result.
			// Instead, construct A as a General.
			amat := blas64.General{
				Rows:   ar,
				Cols:   1,
				Stride: avec.Inc,
				Data:   avec.Data,
			}
			blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
			return
		}
	}
	// Generic fallback: cache row r of a, then form each dot product with
	// column c of b via At.
	m.checkOverlapMatrix(aU)
	m.checkOverlapMatrix(bU)
	row := getFloats(ac, false)
	defer putFloats(row)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			m.mat.Data[r*m.mat.Stride+c] = v
		}
	}
}
  428. // strictCopy copies a into m panicking if the shape of a and m differ.
  429. func strictCopy(m *Dense, a Matrix) {
  430. r, c := m.Copy(a)
  431. if r != m.mat.Rows || c != m.mat.Cols {
  432. // Panic with a string since this
  433. // is not a user-facing panic.
  434. panic(ErrShape.Error())
  435. }
  436. }
// Exp calculates the exponential of the matrix a, e^a, placing the result
// in the receiver. Exp will panic with ErrShape if a is not square.
func (m *Dense) Exp(a Matrix) {
	// The implementation used here is from Functions of Matrices: Theory and Computation
	// Chapter 10, Algorithm 10.20. https://doi.org/10.1137/1.9780898717778.ch10
	r, c := a.Dims()
	if r != c {
		panic(ErrShape)
	}
	m.reuseAs(r, r)
	// A 1×1 matrix exponential is just the scalar exponential.
	if r == 1 {
		m.mat.Data[0] = math.Exp(a.At(0, 0))
		return
	}
	// Padé coefficient rows for the approximants of degree 3, 5, 7 and 9
	// (len(b)-1). A row is used only when the 1-norm of a is at most its
	// theta; larger norms fall through to the degree 13 case below.
	pade := []struct {
		theta float64
		b     []float64
	}{
		{theta: 0.015, b: []float64{
			120, 60, 12, 1,
		}},
		{theta: 0.25, b: []float64{
			30240, 15120, 3360, 420, 30, 1,
		}},
		{theta: 0.95, b: []float64{
			17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1,
		}},
		{theta: 2.1, b: []float64{
			17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1,
		}},
	}
	// a1 is the receiver itself: it holds a copy of a (scaled in place in
	// the degree 13 path) and is finally overwritten by the Solve result.
	a1 := m
	a1.Copy(a)
	// v and u accumulate the even (V) and odd (U) parts of the Padé
	// approximant through Axpy, so they must start at zero; they are
	// requested with getWorkspace's third argument set to true
	// (presumably "clear" — confirm against getWorkspace).
	v := getWorkspace(r, r, true)
	vraw := v.RawMatrix()
	n := r * r
	// The workspaces are viewed as flat length-n vectors so whole matrices
	// can be combined with a single blas64.Axpy.
	vvec := blas64.Vector{N: n, Inc: 1, Data: vraw.Data}
	defer putWorkspace(v)
	u := getWorkspace(r, r, true)
	uraw := u.RawMatrix()
	uvec := blas64.Vector{N: n, Inc: 1, Data: uraw.Data}
	defer putWorkspace(u)
	a2 := getWorkspace(r, r, false)
	defer putWorkspace(a2)
	n1 := Norm(a, 1)
	for i, t := range pade {
		if n1 > t.theta {
			continue
		}
		// This loop only executes once, so
		// this is not as horrible as it looks.
		p := getWorkspace(r, r, true)
		praw := p.RawMatrix()
		pvec := blas64.Vector{N: n, Inc: 1, Data: praw.Data}
		defer putWorkspace(p)
		// Initialise p = I, v = b_0*I, u = b_1*I on the diagonal.
		for k := 0; k < r; k++ {
			p.set(k, k, 1)
			v.set(k, k, t.b[0])
			u.set(k, k, t.b[1])
		}
		// p walks through the even powers A^2, A^4, ...; even-indexed
		// coefficients go into v, odd-indexed ones into u.
		a2.Mul(a1, a1)
		for j := 0; j <= i; j++ {
			p.Mul(p, a2)
			blas64.Axpy(t.b[2*j+2], pvec, vvec)
			blas64.Axpy(t.b[2*j+3], pvec, uvec)
		}
		// The odd part carries an extra factor of A.
		u.Mul(a1, u)
		// Use p as a workspace here and
		// rename u for the second call's
		// receiver.
		vmu, vpu := u, p
		vpu.Add(v, u)
		vmu.Sub(v, u)
		// Presumably solves (V-U) X = (V+U), the Padé approximant of e^a.
		// NOTE(review): any error Solve may report is not checked here.
		m.Solve(vmu, vpu)
		return
	}
	// Remaining Padé table line.
	const theta13 = 5.4
	b := [...]float64{
		64764752532480000, 32382376266240000, 7771770303897600, 1187353796428800,
		129060195264000, 10559470521600, 670442572800, 33522128640,
		1323241920, 40840800, 960960, 16380, 182, 1,
	}
	// Scaling step: when n1 > theta13, scale a1 by 2^-s (s rounded up) so
	// the degree 13 approximant applies; the result is squared s times at
	// the end.
	s := math.Log2(n1 / theta13)
	if s >= 0 {
		s = math.Ceil(s)
		a1.Scale(1/math.Pow(2, s), a1)
	}
	a2.Mul(a1, a1)
	// i is the identity matrix (not a loop index).
	i := getWorkspace(r, r, true)
	for j := 0; j < r; j++ {
		i.set(j, j, 1)
	}
	iraw := i.RawMatrix()
	ivec := blas64.Vector{N: n, Inc: 1, Data: iraw.Data}
	defer putWorkspace(i)
	a2raw := a2.RawMatrix()
	a2vec := blas64.Vector{N: n, Inc: 1, Data: a2raw.Data}
	a4 := getWorkspace(r, r, false)
	a4raw := a4.RawMatrix()
	a4vec := blas64.Vector{N: n, Inc: 1, Data: a4raw.Data}
	defer putWorkspace(a4)
	a4.Mul(a2, a2)
	a6 := getWorkspace(r, r, false)
	a6raw := a6.RawMatrix()
	a6vec := blas64.Vector{N: n, Inc: 1, Data: a6raw.Data}
	defer putWorkspace(a6)
	a6.Mul(a2, a4)
	// V = A_6(b_12*A_6 + b_10*A_4 + b_8*A_2) + b_6*A_6 + b_4*A_4 + b_2*A_2 +b_0*I
	blas64.Axpy(b[12], a6vec, vvec)
	blas64.Axpy(b[10], a4vec, vvec)
	blas64.Axpy(b[8], a2vec, vvec)
	v.Mul(v, a6)
	blas64.Axpy(b[6], a6vec, vvec)
	blas64.Axpy(b[4], a4vec, vvec)
	blas64.Axpy(b[2], a2vec, vvec)
	blas64.Axpy(b[0], ivec, vvec)
	// U = A(A_6(b_13*A_6 + b_11*A_4 + b_9*A_2) + b_7*A_6 + b_5*A_4 + b_3*A_2 +b_1*I)
	blas64.Axpy(b[13], a6vec, uvec)
	blas64.Axpy(b[11], a4vec, uvec)
	blas64.Axpy(b[9], a2vec, uvec)
	u.Mul(u, a6)
	blas64.Axpy(b[7], a6vec, uvec)
	blas64.Axpy(b[5], a4vec, uvec)
	blas64.Axpy(b[3], a2vec, uvec)
	blas64.Axpy(b[1], ivec, uvec)
	u.Mul(u, a1)
	// Use i as a workspace here and
	// rename u for the second call's
	// receiver.
	vmu, vpu := u, i
	vpu.Add(v, u)
	vmu.Sub(v, u)
	// NOTE(review): any error Solve may report is not checked here.
	m.Solve(vmu, vpu)
	// Squaring step: undo the 2^-s scaling by squaring the result s times.
	for ; s > 0; s-- {
		m.Mul(m, m)
	}
}
  575. // Pow calculates the integral power of the matrix a to n, placing the result
  576. // in the receiver. Pow will panic if n is negative or if a is not square.
  577. func (m *Dense) Pow(a Matrix, n int) {
  578. if n < 0 {
  579. panic("matrix: illegal power")
  580. }
  581. r, c := a.Dims()
  582. if r != c {
  583. panic(ErrShape)
  584. }
  585. m.reuseAs(r, c)
  586. // Take possible fast paths.
  587. switch n {
  588. case 0:
  589. for i := 0; i < r; i++ {
  590. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
  591. m.mat.Data[i*m.mat.Stride+i] = 1
  592. }
  593. return
  594. case 1:
  595. m.Copy(a)
  596. return
  597. case 2:
  598. m.Mul(a, a)
  599. return
  600. }
  601. // Perform iterative exponentiation by squaring in work space.
  602. w := getWorkspace(r, r, false)
  603. w.Copy(a)
  604. s := getWorkspace(r, r, false)
  605. s.Copy(a)
  606. x := getWorkspace(r, r, false)
  607. for n--; n > 0; n >>= 1 {
  608. if n&1 != 0 {
  609. x.Mul(w, s)
  610. w, x = x, w
  611. }
  612. if n != 1 {
  613. x.Mul(s, s)
  614. s, x = x, s
  615. }
  616. }
  617. m.Copy(w)
  618. putWorkspace(w)
  619. putWorkspace(s)
  620. putWorkspace(x)
  621. }
  622. // Scale multiplies the elements of a by f, placing the result in the receiver.
  623. //
  624. // See the Scaler interface for more information.
  625. func (m *Dense) Scale(f float64, a Matrix) {
  626. ar, ac := a.Dims()
  627. m.reuseAs(ar, ac)
  628. aU, aTrans := untranspose(a)
  629. if rm, ok := aU.(RawMatrixer); ok {
  630. amat := rm.RawMatrix()
  631. if m == aU || m.checkOverlap(amat) {
  632. var restore func()
  633. m, restore = m.isolatedWorkspace(a)
  634. defer restore()
  635. }
  636. if !aTrans {
  637. for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
  638. for i, v := range amat.Data[ja : ja+ac] {
  639. m.mat.Data[i+jm] = v * f
  640. }
  641. }
  642. } else {
  643. for ja, jm := 0, 0; ja < ac*amat.Stride; ja, jm = ja+amat.Stride, jm+1 {
  644. for i, v := range amat.Data[ja : ja+ar] {
  645. m.mat.Data[i*m.mat.Stride+jm] = v * f
  646. }
  647. }
  648. }
  649. return
  650. }
  651. m.checkOverlapMatrix(a)
  652. for r := 0; r < ar; r++ {
  653. for c := 0; c < ac; c++ {
  654. m.set(r, c, f*a.At(r, c))
  655. }
  656. }
  657. }
  658. // Apply applies the function fn to each of the elements of a, placing the
  659. // resulting matrix in the receiver. The function fn takes a row/column
  660. // index and element value and returns some function of that tuple.
  661. func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
  662. ar, ac := a.Dims()
  663. m.reuseAs(ar, ac)
  664. aU, aTrans := untranspose(a)
  665. if rm, ok := aU.(RawMatrixer); ok {
  666. amat := rm.RawMatrix()
  667. if m == aU || m.checkOverlap(amat) {
  668. var restore func()
  669. m, restore = m.isolatedWorkspace(a)
  670. defer restore()
  671. }
  672. if !aTrans {
  673. for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
  674. for i, v := range amat.Data[ja : ja+ac] {
  675. m.mat.Data[i+jm] = fn(j, i, v)
  676. }
  677. }
  678. } else {
  679. for j, ja, jm := 0, 0, 0; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+1 {
  680. for i, v := range amat.Data[ja : ja+ar] {
  681. m.mat.Data[i*m.mat.Stride+jm] = fn(i, j, v)
  682. }
  683. }
  684. }
  685. return
  686. }
  687. m.checkOverlapMatrix(a)
  688. for r := 0; r < ar; r++ {
  689. for c := 0; c < ac; c++ {
  690. m.set(r, c, fn(r, c, a.At(r, c)))
  691. }
  692. }
  693. }
  694. // RankOne performs a rank-one update to the matrix a and stores the result
  695. // in the receiver. If a is zero, see Outer.
  696. // m = a + alpha * x * y'
  697. func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
  698. ar, ac := a.Dims()
  699. xr, xc := x.Dims()
  700. if xr != ar || xc != 1 {
  701. panic(ErrShape)
  702. }
  703. yr, yc := y.Dims()
  704. if yr != ac || yc != 1 {
  705. panic(ErrShape)
  706. }
  707. if a != m {
  708. aU, _ := untranspose(a)
  709. if rm, ok := aU.(RawMatrixer); ok {
  710. m.checkOverlap(rm.RawMatrix())
  711. }
  712. }
  713. var xmat, ymat blas64.Vector
  714. fast := true
  715. xU, _ := untranspose(x)
  716. if rv, ok := xU.(RawVectorer); ok {
  717. xmat = rv.RawVector()
  718. m.checkOverlap((&VecDense{mat: xmat}).asGeneral())
  719. } else {
  720. fast = false
  721. }
  722. yU, _ := untranspose(y)
  723. if rv, ok := yU.(RawVectorer); ok {
  724. ymat = rv.RawVector()
  725. m.checkOverlap((&VecDense{mat: ymat}).asGeneral())
  726. } else {
  727. fast = false
  728. }
  729. if fast {
  730. if m != a {
  731. m.reuseAs(ar, ac)
  732. m.Copy(a)
  733. }
  734. blas64.Ger(alpha, xmat, ymat, m.mat)
  735. return
  736. }
  737. m.reuseAs(ar, ac)
  738. for i := 0; i < ar; i++ {
  739. for j := 0; j < ac; j++ {
  740. m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
  741. }
  742. }
  743. }
  744. // Outer calculates the outer product of the column vectors x and y,
  745. // and stores the result in the receiver.
  746. // m = alpha * x * y'
  747. // In order to update an existing matrix, see RankOne.
  748. func (m *Dense) Outer(alpha float64, x, y Vector) {
  749. xr, xc := x.Dims()
  750. if xc != 1 {
  751. panic(ErrShape)
  752. }
  753. yr, yc := y.Dims()
  754. if yc != 1 {
  755. panic(ErrShape)
  756. }
  757. r := xr
  758. c := yr
  759. // Copied from reuseAs with use replaced by useZeroed
  760. // and a final zero of the matrix elements if we pass
  761. // the shape checks.
  762. // TODO(kortschak): Factor out into reuseZeroedAs if
  763. // we find another case that needs it.
  764. if m.mat.Rows > m.capRows || m.mat.Cols > m.capCols {
  765. // Panic as a string, not a mat.Error.
  766. panic("mat: caps not correctly set")
  767. }
  768. if m.IsZero() {
  769. m.mat = blas64.General{
  770. Rows: r,
  771. Cols: c,
  772. Stride: c,
  773. Data: useZeroed(m.mat.Data, r*c),
  774. }
  775. m.capRows = r
  776. m.capCols = c
  777. } else if r != m.mat.Rows || c != m.mat.Cols {
  778. panic(ErrShape)
  779. }
  780. var xmat, ymat blas64.Vector
  781. fast := true
  782. xU, _ := untranspose(x)
  783. if rv, ok := xU.(RawVectorer); ok {
  784. xmat = rv.RawVector()
  785. m.checkOverlap((&VecDense{mat: xmat}).asGeneral())
  786. } else {
  787. fast = false
  788. }
  789. yU, _ := untranspose(y)
  790. if rv, ok := yU.(RawVectorer); ok {
  791. ymat = rv.RawVector()
  792. m.checkOverlap((&VecDense{mat: ymat}).asGeneral())
  793. } else {
  794. fast = false
  795. }
  796. if fast {
  797. for i := 0; i < r; i++ {
  798. zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
  799. }
  800. blas64.Ger(alpha, xmat, ymat, m.mat)
  801. return
  802. }
  803. for i := 0; i < r; i++ {
  804. for j := 0; j < c; j++ {
  805. m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))
  806. }
  807. }
  808. }