dlarfb.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
  4. package gonum
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. "gonum.org/v1/gonum/lapack"
  9. )
  10. // Dlarfb applies a block reflector to a matrix.
  11. //
  12. // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
  13. // c = h * c if side == Left and trans == NoTrans
  14. // c = c * h if side == Right and trans == NoTrans
  15. // c = hᵀ * c if side == Left and trans == Trans
  16. // c = c * hᵀ if side == Right and trans == Trans
  17. // h is a product of elementary reflectors. direct sets the direction of multiplication
  18. // h = h_1 * h_2 * ... * h_k if direct == Forward
  19. // h = h_k * h_k-1 * ... * h_1 if direct == Backward
  20. // The combination of direct and store defines the orientation of the elementary
  21. // reflectors. In all cases the ones on the diagonal are implicitly represented.
  22. //
  23. // If direct == lapack.Forward and store == lapack.ColumnWise
  24. // V = [ 1 ]
  25. // [v1 1 ]
  26. // [v1 v2 1]
  27. // [v1 v2 v3]
  28. // [v1 v2 v3]
  29. // If direct == lapack.Forward and store == lapack.RowWise
  30. // V = [ 1 v1 v1 v1 v1]
  31. // [ 1 v2 v2 v2]
  32. // [ 1 v3 v3]
  33. // If direct == lapack.Backward and store == lapack.ColumnWise
  34. // V = [v1 v2 v3]
  35. // [v1 v2 v3]
  36. // [ 1 v2 v3]
  37. // [ 1 v3]
  38. // [ 1]
  39. // If direct == lapack.Backward and store == lapack.RowWise
  40. // V = [v1 v1 1 ]
  41. // [v2 v2 v2 1 ]
  42. // [v3 v3 v3 v3 1]
  43. // An elementary reflector can be explicitly constructed by extracting the
  44. // corresponding elements of v, placing a 1 where the diagonal would be, and
  45. // placing zeros in the remaining elements.
  46. //
  47. // t is a k×k matrix containing the block reflector, and this function will panic
  48. // if t is not of sufficient size. See Dlarft for more information.
  49. //
  50. // work is a temporary storage matrix with stride ldwork.
  51. // work must be of size at least n×k side == Left and m×k if side == Right, and
  52. // this function will panic if this size is not met.
  53. //
  54. // Dlarfb is an internal routine. It is exported for testing purposes.
  55. func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
  56. nv := m
  57. if side == blas.Right {
  58. nv = n
  59. }
  60. switch {
  61. case side != blas.Left && side != blas.Right:
  62. panic(badSide)
  63. case trans != blas.Trans && trans != blas.NoTrans:
  64. panic(badTrans)
  65. case direct != lapack.Forward && direct != lapack.Backward:
  66. panic(badDirect)
  67. case store != lapack.ColumnWise && store != lapack.RowWise:
  68. panic(badStoreV)
  69. case m < 0:
  70. panic(mLT0)
  71. case n < 0:
  72. panic(nLT0)
  73. case k < 0:
  74. panic(kLT0)
  75. case store == lapack.ColumnWise && ldv < max(1, k):
  76. panic(badLdV)
  77. case store == lapack.RowWise && ldv < max(1, nv):
  78. panic(badLdV)
  79. case ldt < max(1, k):
  80. panic(badLdT)
  81. case ldc < max(1, n):
  82. panic(badLdC)
  83. case ldwork < max(1, k):
  84. panic(badLdWork)
  85. }
  86. if m == 0 || n == 0 {
  87. return
  88. }
  89. nw := n
  90. if side == blas.Right {
  91. nw = m
  92. }
  93. switch {
  94. case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k:
  95. panic(shortV)
  96. case store == lapack.RowWise && len(v) < (k-1)*ldv+nv:
  97. panic(shortV)
  98. case len(t) < (k-1)*ldt+k:
  99. panic(shortT)
  100. case len(c) < (m-1)*ldc+n:
  101. panic(shortC)
  102. case len(work) < (nw-1)*ldwork+k:
  103. panic(shortWork)
  104. }
  105. bi := blas64.Implementation()
  106. transt := blas.Trans
  107. if trans == blas.Trans {
  108. transt = blas.NoTrans
  109. }
  110. // TODO(btracey): This follows the original Lapack code where the
  111. // elements are copied into the columns of the working array. The
  112. // loops should go in the other direction so the data is written
  113. // into the rows of work so the copy is not strided. A bigger change
  114. // would be to replace work with workᵀ, but benchmarks would be
  115. // needed to see if the change is merited.
  116. if store == lapack.ColumnWise {
  117. if direct == lapack.Forward {
  118. // V1 is the first k rows of C. V2 is the remaining rows.
  119. if side == blas.Left {
  120. // W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work).
  121. // W = C1.
  122. for j := 0; j < k; j++ {
  123. bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
  124. }
  125. // W = W * V1.
  126. bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
  127. n, k, 1,
  128. v, ldv,
  129. work, ldwork)
  130. if m > k {
  131. // W = W + C2ᵀ V2.
  132. bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
  133. 1, c[k*ldc:], ldc, v[k*ldv:], ldv,
  134. 1, work, ldwork)
  135. }
  136. // W = W * Tᵀ or W * T.
  137. bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
  138. 1, t, ldt,
  139. work, ldwork)
  140. // C -= V * Wᵀ.
  141. if m > k {
  142. // C2 -= V2 * Wᵀ.
  143. bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
  144. -1, v[k*ldv:], ldv, work, ldwork,
  145. 1, c[k*ldc:], ldc)
  146. }
  147. // W *= V1ᵀ.
  148. bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
  149. 1, v, ldv,
  150. work, ldwork)
  151. // C1 -= Wᵀ.
  152. // TODO(btracey): This should use blas.Axpy.
  153. for i := 0; i < n; i++ {
  154. for j := 0; j < k; j++ {
  155. c[j*ldc+i] -= work[i*ldwork+j]
  156. }
  157. }
  158. return
  159. }
  160. // Form C = C * H or C * Hᵀ, where C = (C1 C2).
  161. // W = C1.
  162. for i := 0; i < k; i++ {
  163. bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
  164. }
  165. // W *= V1.
  166. bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
  167. 1, v, ldv,
  168. work, ldwork)
  169. if n > k {
  170. bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
  171. 1, c[k:], ldc, v[k*ldv:], ldv,
  172. 1, work, ldwork)
  173. }
  174. // W *= T or Tᵀ.
  175. bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
  176. 1, t, ldt,
  177. work, ldwork)
  178. if n > k {
  179. bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
  180. -1, work, ldwork, v[k*ldv:], ldv,
  181. 1, c[k:], ldc)
  182. }
  183. // C -= W * Vᵀ.
  184. bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
  185. 1, v, ldv,
  186. work, ldwork)
  187. // C -= W.
  188. // TODO(btracey): This should use blas.Axpy.
  189. for i := 0; i < m; i++ {
  190. for j := 0; j < k; j++ {
  191. c[i*ldc+j] -= work[i*ldwork+j]
  192. }
  193. }
  194. return
  195. }
  196. // V = (V1)
  197. // = (V2) (last k rows)
  198. // Where V2 is unit upper triangular.
  199. if side == blas.Left {
  200. // Form H * C or
  201. // W = Cᵀ V.
  202. // W = C2ᵀ.
  203. for j := 0; j < k; j++ {
  204. bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
  205. }
  206. // W *= V2.
  207. bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
  208. 1, v[(m-k)*ldv:], ldv,
  209. work, ldwork)
  210. if m > k {
  211. // W += C1ᵀ * V1.
  212. bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
  213. 1, c, ldc, v, ldv,
  214. 1, work, ldwork)
  215. }
  216. // W *= T or Tᵀ.
  217. bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
  218. 1, t, ldt,
  219. work, ldwork)
  220. // C -= V * Wᵀ.
  221. if m > k {
  222. bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
  223. -1, v, ldv, work, ldwork,
  224. 1, c, ldc)
  225. }
  226. // W *= V2ᵀ.
  227. bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
  228. 1, v[(m-k)*ldv:], ldv,
  229. work, ldwork)
  230. // C2 -= Wᵀ.
  231. // TODO(btracey): This should use blas.Axpy.
  232. for i := 0; i < n; i++ {
  233. for j := 0; j < k; j++ {
  234. c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
  235. }
  236. }
  237. return
  238. }
  239. // Form C * H or C * Hᵀ where C = (C1 C2).
  240. // W = C * V.
  241. // W = C2.
  242. for j := 0; j < k; j++ {
  243. bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
  244. }
  245. // W = W * V2.
  246. bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
  247. 1, v[(n-k)*ldv:], ldv,
  248. work, ldwork)
  249. if n > k {
  250. bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
  251. 1, c, ldc, v, ldv,
  252. 1, work, ldwork)
  253. }
  254. // W *= T or Tᵀ.
  255. bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
  256. 1, t, ldt,
  257. work, ldwork)
  258. // C -= W * Vᵀ.
  259. if n > k {
  260. // C1 -= W * V1ᵀ.
  261. bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
  262. -1, work, ldwork, v, ldv,
  263. 1, c, ldc)
  264. }
  265. // W *= V2ᵀ.
  266. bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
  267. 1, v[(n-k)*ldv:], ldv,
  268. work, ldwork)
  269. // C2 -= W.
  270. // TODO(btracey): This should use blas.Axpy.
  271. for i := 0; i < m; i++ {
  272. for j := 0; j < k; j++ {
  273. c[i*ldc+n-k+j] -= work[i*ldwork+j]
  274. }
  275. }
  276. return
  277. }
  278. // Store = Rowwise.
  279. if direct == lapack.Forward {
  280. // V = (V1 V2) where v1 is unit upper triangular.
  281. if side == blas.Left {
  282. // Form H * C or Hᵀ * C where C = (C1; C2).
  283. // W = Cᵀ * Vᵀ.
  284. // W = C1ᵀ.
  285. for j := 0; j < k; j++ {
  286. bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
  287. }
  288. // W *= V1ᵀ.
  289. bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
  290. 1, v, ldv,
  291. work, ldwork)
  292. if m > k {
  293. bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
  294. 1, c[k*ldc:], ldc, v[k:], ldv,
  295. 1, work, ldwork)
  296. }
  297. // W *= T or Tᵀ.
  298. bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
  299. 1, t, ldt,
  300. work, ldwork)
  301. // C -= Vᵀ * Wᵀ.
  302. if m > k {
  303. bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
  304. -1, v[k:], ldv, work, ldwork,
  305. 1, c[k*ldc:], ldc)
  306. }
  307. // W *= V1.
  308. bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
  309. 1, v, ldv,
  310. work, ldwork)
  311. // C1 -= Wᵀ.
  312. // TODO(btracey): This should use blas.Axpy.
  313. for i := 0; i < n; i++ {
  314. for j := 0; j < k; j++ {
  315. c[j*ldc+i] -= work[i*ldwork+j]
  316. }
  317. }
  318. return
  319. }
  320. // Form C * H or C * Hᵀ where C = (C1 C2).
  321. // W = C * Vᵀ.
  322. // W = C1.
  323. for j := 0; j < k; j++ {
  324. bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
  325. }
  326. // W *= V1ᵀ.
  327. bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
  328. 1, v, ldv,
  329. work, ldwork)
  330. if n > k {
  331. bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
  332. 1, c[k:], ldc, v[k:], ldv,
  333. 1, work, ldwork)
  334. }
  335. // W *= T or Tᵀ.
  336. bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
  337. 1, t, ldt,
  338. work, ldwork)
  339. // C -= W * V.
  340. if n > k {
  341. bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
  342. -1, work, ldwork, v[k:], ldv,
  343. 1, c[k:], ldc)
  344. }
  345. // W *= V1.
  346. bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
  347. 1, v, ldv,
  348. work, ldwork)
  349. // C1 -= W.
  350. // TODO(btracey): This should use blas.Axpy.
  351. for i := 0; i < m; i++ {
  352. for j := 0; j < k; j++ {
  353. c[i*ldc+j] -= work[i*ldwork+j]
  354. }
  355. }
  356. return
  357. }
  358. // V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
  359. if side == blas.Left {
  360. // Form H * C or Hᵀ C where C = (C1 ; C2).
  361. // W = Cᵀ * Vᵀ.
  362. // W = C2ᵀ.
  363. for j := 0; j < k; j++ {
  364. bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
  365. }
  366. // W *= V2ᵀ.
  367. bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
  368. 1, v[m-k:], ldv,
  369. work, ldwork)
  370. if m > k {
  371. bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
  372. 1, c, ldc, v, ldv,
  373. 1, work, ldwork)
  374. }
  375. // W *= T or Tᵀ.
  376. bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
  377. 1, t, ldt,
  378. work, ldwork)
  379. // C -= Vᵀ * Wᵀ.
  380. if m > k {
  381. bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
  382. -1, v, ldv, work, ldwork,
  383. 1, c, ldc)
  384. }
  385. // W *= V2.
  386. bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
  387. 1, v[m-k:], ldv,
  388. work, ldwork)
  389. // C2 -= Wᵀ.
  390. // TODO(btracey): This should use blas.Axpy.
  391. for i := 0; i < n; i++ {
  392. for j := 0; j < k; j++ {
  393. c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
  394. }
  395. }
  396. return
  397. }
  398. // Form C * H or C * Hᵀ where C = (C1 C2).
  399. // W = C * Vᵀ.
  400. // W = C2.
  401. for j := 0; j < k; j++ {
  402. bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
  403. }
  404. // W *= V2ᵀ.
  405. bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
  406. 1, v[n-k:], ldv,
  407. work, ldwork)
  408. if n > k {
  409. bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
  410. 1, c, ldc, v, ldv,
  411. 1, work, ldwork)
  412. }
  413. // W *= T or Tᵀ.
  414. bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
  415. 1, t, ldt,
  416. work, ldwork)
  417. // C -= W * V.
  418. if n > k {
  419. bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
  420. -1, work, ldwork, v, ldv,
  421. 1, c, ldc)
  422. }
  423. // W *= V2.
  424. bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
  425. 1, v[n-k:], ldv,
  426. work, ldwork)
  427. // C1 -= W.
  428. // TODO(btracey): This should use blas.Axpy.
  429. for i := 0; i < m; i++ {
  430. for j := 0; j < k; j++ {
  431. c[i*ldc+n-k+j] -= work[i*ldwork+j]
  432. }
  433. }
  434. }