dlarft.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gonum
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. "gonum.org/v1/gonum/lapack"
  9. )
  10. // Dlarft forms the triangular factor T of a block reflector H, storing the answer
  11. // in t.
  12. // H = I - V * T * Vᵀ if store == lapack.ColumnWise
  13. // H = I - Vᵀ * T * V if store == lapack.RowWise
  14. // H is defined by a product of the elementary reflectors where
  15. // H = H_0 * H_1 * ... * H_{k-1} if direct == lapack.Forward
  16. // H = H_{k-1} * ... * H_1 * H_0 if direct == lapack.Backward
  17. //
  18. // t is a k×k triangular matrix. t is upper triangular if direct = lapack.Forward
  19. // and lower triangular otherwise. This function will panic if t is not of
  20. // sufficient size.
  21. //
  22. // store describes the storage of the elementary reflectors in v. See
  23. // Dlarfb for a description of layout.
  24. //
  25. // tau contains the scalar factors of the elementary reflectors H_i.
  26. //
  27. // Dlarft is an internal routine. It is exported for testing purposes.
  28. func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int) {
  29. mv, nv := n, k
  30. if store == lapack.RowWise {
  31. mv, nv = k, n
  32. }
  33. switch {
  34. case direct != lapack.Forward && direct != lapack.Backward:
  35. panic(badDirect)
  36. case store != lapack.RowWise && store != lapack.ColumnWise:
  37. panic(badStoreV)
  38. case n < 0:
  39. panic(nLT0)
  40. case k < 1:
  41. panic(kLT1)
  42. case ldv < max(1, nv):
  43. panic(badLdV)
  44. case len(tau) < k:
  45. panic(shortTau)
  46. case ldt < max(1, k):
  47. panic(shortT)
  48. }
  49. if n == 0 {
  50. return
  51. }
  52. switch {
  53. case len(v) < (mv-1)*ldv+nv:
  54. panic(shortV)
  55. case len(t) < (k-1)*ldt+k:
  56. panic(shortT)
  57. }
  58. bi := blas64.Implementation()
  59. // TODO(btracey): There are a number of minor obvious loop optimizations here.
  60. // TODO(btracey): It may be possible to rearrange some of the code so that
  61. // index of 1 is more common in the Dgemv.
  62. if direct == lapack.Forward {
  63. prevlastv := n - 1
  64. for i := 0; i < k; i++ {
  65. prevlastv = max(i, prevlastv)
  66. if tau[i] == 0 {
  67. for j := 0; j <= i; j++ {
  68. t[j*ldt+i] = 0
  69. }
  70. continue
  71. }
  72. var lastv int
  73. if store == lapack.ColumnWise {
  74. // skip trailing zeros
  75. for lastv = n - 1; lastv >= i+1; lastv-- {
  76. if v[lastv*ldv+i] != 0 {
  77. break
  78. }
  79. }
  80. for j := 0; j < i; j++ {
  81. t[j*ldt+i] = -tau[i] * v[i*ldv+j]
  82. }
  83. j := min(lastv, prevlastv)
  84. bi.Dgemv(blas.Trans, j-i, i,
  85. -tau[i], v[(i+1)*ldv:], ldv, v[(i+1)*ldv+i:], ldv,
  86. 1, t[i:], ldt)
  87. } else {
  88. for lastv = n - 1; lastv >= i+1; lastv-- {
  89. if v[i*ldv+lastv] != 0 {
  90. break
  91. }
  92. }
  93. for j := 0; j < i; j++ {
  94. t[j*ldt+i] = -tau[i] * v[j*ldv+i]
  95. }
  96. j := min(lastv, prevlastv)
  97. bi.Dgemv(blas.NoTrans, i, j-i,
  98. -tau[i], v[i+1:], ldv, v[i*ldv+i+1:], 1,
  99. 1, t[i:], ldt)
  100. }
  101. bi.Dtrmv(blas.Upper, blas.NoTrans, blas.NonUnit, i, t, ldt, t[i:], ldt)
  102. t[i*ldt+i] = tau[i]
  103. if i > 1 {
  104. prevlastv = max(prevlastv, lastv)
  105. } else {
  106. prevlastv = lastv
  107. }
  108. }
  109. return
  110. }
  111. prevlastv := 0
  112. for i := k - 1; i >= 0; i-- {
  113. if tau[i] == 0 {
  114. for j := i; j < k; j++ {
  115. t[j*ldt+i] = 0
  116. }
  117. continue
  118. }
  119. var lastv int
  120. if i < k-1 {
  121. if store == lapack.ColumnWise {
  122. for lastv = 0; lastv < i; lastv++ {
  123. if v[lastv*ldv+i] != 0 {
  124. break
  125. }
  126. }
  127. for j := i + 1; j < k; j++ {
  128. t[j*ldt+i] = -tau[i] * v[(n-k+i)*ldv+j]
  129. }
  130. j := max(lastv, prevlastv)
  131. bi.Dgemv(blas.Trans, n-k+i-j, k-i-1,
  132. -tau[i], v[j*ldv+i+1:], ldv, v[j*ldv+i:], ldv,
  133. 1, t[(i+1)*ldt+i:], ldt)
  134. } else {
  135. for lastv = 0; lastv < i; lastv++ {
  136. if v[i*ldv+lastv] != 0 {
  137. break
  138. }
  139. }
  140. for j := i + 1; j < k; j++ {
  141. t[j*ldt+i] = -tau[i] * v[j*ldv+n-k+i]
  142. }
  143. j := max(lastv, prevlastv)
  144. bi.Dgemv(blas.NoTrans, k-i-1, n-k+i-j,
  145. -tau[i], v[(i+1)*ldv+j:], ldv, v[i*ldv+j:], 1,
  146. 1, t[(i+1)*ldt+i:], ldt)
  147. }
  148. bi.Dtrmv(blas.Lower, blas.NoTrans, blas.NonUnit, k-i-1,
  149. t[(i+1)*ldt+i+1:], ldt,
  150. t[(i+1)*ldt+i:], ldt)
  151. if i > 0 {
  152. prevlastv = min(prevlastv, lastv)
  153. } else {
  154. prevlastv = lastv
  155. }
  156. }
  157. t[i*ldt+i] = tau[i]
  158. }
  159. }