| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450 |
- // Copyright ©2015 The Gonum Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package gonum
- import (
- "gonum.org/v1/gonum/blas"
- "gonum.org/v1/gonum/blas/blas64"
- "gonum.org/v1/gonum/lapack"
- )
- // Dlarfb applies a block reflector to a matrix.
- //
- // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
- // c = h * c if side == Left and trans == NoTrans
- // c = c * h if side == Right and trans == NoTrans
- // c = hᵀ * c if side == Left and trans == Trans
- // c = c * hᵀ if side == Right and trans == Trans
- // h is a product of elementary reflectors. direct sets the direction of multiplication
- // h = h_1 * h_2 * ... * h_k if direct == Forward
- // h = h_k * h_k-1 * ... * h_1 if direct == Backward
- // The combination of direct and store defines the orientation of the elementary
- // reflectors. In all cases the ones on the diagonal are implicitly represented.
- //
- // If direct == lapack.Forward and store == lapack.ColumnWise
- // V = [ 1 ]
- // [v1 1 ]
- // [v1 v2 1]
- // [v1 v2 v3]
- // [v1 v2 v3]
- // If direct == lapack.Forward and store == lapack.RowWise
- // V = [ 1 v1 v1 v1 v1]
- // [ 1 v2 v2 v2]
- // [ 1 v3 v3]
- // If direct == lapack.Backward and store == lapack.ColumnWise
- // V = [v1 v2 v3]
- // [v1 v2 v3]
- // [ 1 v2 v3]
- // [ 1 v3]
- // [ 1]
- // If direct == lapack.Backward and store == lapack.RowWise
- // V = [v1 v1 1 ]
- // [v2 v2 v2 1 ]
- // [v3 v3 v3 v3 1]
- // An elementary reflector can be explicitly constructed by extracting the
- // corresponding elements of v, placing a 1 where the diagonal would be, and
- // placing zeros in the remaining elements.
- //
- // t is a k×k matrix containing the block reflector, and this function will panic
- // if t is not of sufficient size. See Dlarft for more information.
- //
- // work is a temporary storage matrix with stride ldwork.
- // work must be of size at least n×k side == Left and m×k if side == Right, and
- // this function will panic if this size is not met.
- //
- // Dlarfb is an internal routine. It is exported for testing purposes.
- func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
- nv := m
- if side == blas.Right {
- nv = n
- }
- switch {
- case side != blas.Left && side != blas.Right:
- panic(badSide)
- case trans != blas.Trans && trans != blas.NoTrans:
- panic(badTrans)
- case direct != lapack.Forward && direct != lapack.Backward:
- panic(badDirect)
- case store != lapack.ColumnWise && store != lapack.RowWise:
- panic(badStoreV)
- case m < 0:
- panic(mLT0)
- case n < 0:
- panic(nLT0)
- case k < 0:
- panic(kLT0)
- case store == lapack.ColumnWise && ldv < max(1, k):
- panic(badLdV)
- case store == lapack.RowWise && ldv < max(1, nv):
- panic(badLdV)
- case ldt < max(1, k):
- panic(badLdT)
- case ldc < max(1, n):
- panic(badLdC)
- case ldwork < max(1, k):
- panic(badLdWork)
- }
- if m == 0 || n == 0 {
- return
- }
- nw := n
- if side == blas.Right {
- nw = m
- }
- switch {
- case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k:
- panic(shortV)
- case store == lapack.RowWise && len(v) < (k-1)*ldv+nv:
- panic(shortV)
- case len(t) < (k-1)*ldt+k:
- panic(shortT)
- case len(c) < (m-1)*ldc+n:
- panic(shortC)
- case len(work) < (nw-1)*ldwork+k:
- panic(shortWork)
- }
- bi := blas64.Implementation()
- transt := blas.Trans
- if trans == blas.Trans {
- transt = blas.NoTrans
- }
- // TODO(btracey): This follows the original Lapack code where the
- // elements are copied into the columns of the working array. The
- // loops should go in the other direction so the data is written
- // into the rows of work so the copy is not strided. A bigger change
- // would be to replace work with workᵀ, but benchmarks would be
- // needed to see if the change is merited.
- if store == lapack.ColumnWise {
- if direct == lapack.Forward {
- // V1 is the first k rows of C. V2 is the remaining rows.
- if side == blas.Left {
- // W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work).
- // W = C1.
- for j := 0; j < k; j++ {
- bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
- }
- // W = W * V1.
- bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
- n, k, 1,
- v, ldv,
- work, ldwork)
- if m > k {
- // W = W + C2ᵀ V2.
- bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
- 1, c[k*ldc:], ldc, v[k*ldv:], ldv,
- 1, work, ldwork)
- }
- // W = W * Tᵀ or W * T.
- bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
- 1, t, ldt,
- work, ldwork)
- // C -= V * Wᵀ.
- if m > k {
- // C2 -= V2 * Wᵀ.
- bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
- -1, v[k*ldv:], ldv, work, ldwork,
- 1, c[k*ldc:], ldc)
- }
- // W *= V1ᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
- 1, v, ldv,
- work, ldwork)
- // C1 -= Wᵀ.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < n; i++ {
- for j := 0; j < k; j++ {
- c[j*ldc+i] -= work[i*ldwork+j]
- }
- }
- return
- }
- // Form C = C * H or C * Hᵀ, where C = (C1 C2).
- // W = C1.
- for i := 0; i < k; i++ {
- bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
- }
- // W *= V1.
- bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
- 1, v, ldv,
- work, ldwork)
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
- 1, c[k:], ldc, v[k*ldv:], ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
- 1, t, ldt,
- work, ldwork)
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
- -1, work, ldwork, v[k*ldv:], ldv,
- 1, c[k:], ldc)
- }
- // C -= W * Vᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
- 1, v, ldv,
- work, ldwork)
- // C -= W.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < m; i++ {
- for j := 0; j < k; j++ {
- c[i*ldc+j] -= work[i*ldwork+j]
- }
- }
- return
- }
- // V = (V1)
- // = (V2) (last k rows)
- // Where V2 is unit upper triangular.
- if side == blas.Left {
- // Form H * C or
- // W = Cᵀ V.
- // W = C2ᵀ.
- for j := 0; j < k; j++ {
- bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
- }
- // W *= V2.
- bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
- 1, v[(m-k)*ldv:], ldv,
- work, ldwork)
- if m > k {
- // W += C1ᵀ * V1.
- bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
- 1, c, ldc, v, ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
- 1, t, ldt,
- work, ldwork)
- // C -= V * Wᵀ.
- if m > k {
- bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
- -1, v, ldv, work, ldwork,
- 1, c, ldc)
- }
- // W *= V2ᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
- 1, v[(m-k)*ldv:], ldv,
- work, ldwork)
- // C2 -= Wᵀ.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < n; i++ {
- for j := 0; j < k; j++ {
- c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
- }
- }
- return
- }
- // Form C * H or C * Hᵀ where C = (C1 C2).
- // W = C * V.
- // W = C2.
- for j := 0; j < k; j++ {
- bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
- }
- // W = W * V2.
- bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
- 1, v[(n-k)*ldv:], ldv,
- work, ldwork)
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
- 1, c, ldc, v, ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
- 1, t, ldt,
- work, ldwork)
- // C -= W * Vᵀ.
- if n > k {
- // C1 -= W * V1ᵀ.
- bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
- -1, work, ldwork, v, ldv,
- 1, c, ldc)
- }
- // W *= V2ᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
- 1, v[(n-k)*ldv:], ldv,
- work, ldwork)
- // C2 -= W.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < m; i++ {
- for j := 0; j < k; j++ {
- c[i*ldc+n-k+j] -= work[i*ldwork+j]
- }
- }
- return
- }
- // Store = Rowwise.
- if direct == lapack.Forward {
- // V = (V1 V2) where v1 is unit upper triangular.
- if side == blas.Left {
- // Form H * C or Hᵀ * C where C = (C1; C2).
- // W = Cᵀ * Vᵀ.
- // W = C1ᵀ.
- for j := 0; j < k; j++ {
- bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
- }
- // W *= V1ᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
- 1, v, ldv,
- work, ldwork)
- if m > k {
- bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
- 1, c[k*ldc:], ldc, v[k:], ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
- 1, t, ldt,
- work, ldwork)
- // C -= Vᵀ * Wᵀ.
- if m > k {
- bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
- -1, v[k:], ldv, work, ldwork,
- 1, c[k*ldc:], ldc)
- }
- // W *= V1.
- bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
- 1, v, ldv,
- work, ldwork)
- // C1 -= Wᵀ.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < n; i++ {
- for j := 0; j < k; j++ {
- c[j*ldc+i] -= work[i*ldwork+j]
- }
- }
- return
- }
- // Form C * H or C * Hᵀ where C = (C1 C2).
- // W = C * Vᵀ.
- // W = C1.
- for j := 0; j < k; j++ {
- bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
- }
- // W *= V1ᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
- 1, v, ldv,
- work, ldwork)
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
- 1, c[k:], ldc, v[k:], ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
- 1, t, ldt,
- work, ldwork)
- // C -= W * V.
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
- -1, work, ldwork, v[k:], ldv,
- 1, c[k:], ldc)
- }
- // W *= V1.
- bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
- 1, v, ldv,
- work, ldwork)
- // C1 -= W.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < m; i++ {
- for j := 0; j < k; j++ {
- c[i*ldc+j] -= work[i*ldwork+j]
- }
- }
- return
- }
- // V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
- if side == blas.Left {
- // Form H * C or Hᵀ C where C = (C1 ; C2).
- // W = Cᵀ * Vᵀ.
- // W = C2ᵀ.
- for j := 0; j < k; j++ {
- bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
- }
- // W *= V2ᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
- 1, v[m-k:], ldv,
- work, ldwork)
- if m > k {
- bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
- 1, c, ldc, v, ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
- 1, t, ldt,
- work, ldwork)
- // C -= Vᵀ * Wᵀ.
- if m > k {
- bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
- -1, v, ldv, work, ldwork,
- 1, c, ldc)
- }
- // W *= V2.
- bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
- 1, v[m-k:], ldv,
- work, ldwork)
- // C2 -= Wᵀ.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < n; i++ {
- for j := 0; j < k; j++ {
- c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
- }
- }
- return
- }
- // Form C * H or C * Hᵀ where C = (C1 C2).
- // W = C * Vᵀ.
- // W = C2.
- for j := 0; j < k; j++ {
- bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
- }
- // W *= V2ᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
- 1, v[n-k:], ldv,
- work, ldwork)
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
- 1, c, ldc, v, ldv,
- 1, work, ldwork)
- }
- // W *= T or Tᵀ.
- bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
- 1, t, ldt,
- work, ldwork)
- // C -= W * V.
- if n > k {
- bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
- -1, work, ldwork, v, ldv,
- 1, c, ldc)
- }
- // W *= V2.
- bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
- 1, v[n-k:], ldv,
- work, ldwork)
- // C1 -= W.
- // TODO(btracey): This should use blas.Axpy.
- for i := 0; i < m; i++ {
- for j := 0; j < k; j++ {
- c[i*ldc+n-k+j] -= work[i*ldwork+j]
- }
- }
- }
|