symband.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. // Copyright ©2017 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. )
// Compile-time assertions that *SymBandDense satisfies each of the
// interfaces listed below. symBandDense is a typed nil pointer used only
// as the right-hand side of these assignments in this block.
var (
	symBandDense *SymBandDense
	_ Matrix = symBandDense
	_ Symmetric = symBandDense
	_ Banded = symBandDense
	_ SymBanded = symBandDense
	_ RawSymBander = symBandDense
	_ MutableSymBanded = symBandDense
	_ NonZeroDoer = symBandDense
	_ RowNonZeroDoer = symBandDense
	_ ColNonZeroDoer = symBandDense
)
// SymBandDense represents a symmetric band matrix in dense storage format.
type SymBandDense struct {
	// mat is the underlying BLAS symmetric band representation of the
	// matrix, including its backing data slice.
	mat blas64.SymmetricBand
}
// SymBanded is a symmetric band matrix interface type. It extends Banded
// with methods that report the size and bandwidth of the matrix.
type SymBanded interface {
	Banded
	// Symmetric returns the number of rows/columns in the matrix.
	Symmetric() int
	// SymBand returns the number of rows/columns in the matrix, n, and the
	// size of the bandwidth, k.
	SymBand() (n, k int)
}
// MutableSymBanded is a symmetric band matrix interface type that allows
// elements to be altered. It extends SymBanded with the SetSymBand method.
type MutableSymBanded interface {
	SymBanded
	// SetSymBand takes a row index i, a column index j and a value v.
	SetSymBand(i, j int, v float64)
}
// A RawSymBander can return a blas64.SymmetricBand representation of the receiver.
// Changes to the blas64.SymmetricBand.Data slice will be reflected in the original
// matrix; changes to the N, K, Stride and Uplo fields will not.
type RawSymBander interface {
	// RawSymBand returns the blas64.SymmetricBand representation of the receiver.
	RawSymBand() blas64.SymmetricBand
}
  46. // NewSymBandDense creates a new SymBand matrix with n rows and columns. If data == nil,
  47. // a new slice is allocated for the backing slice. If len(data) == n*(k+1),
  48. // data is used as the backing slice, and changes to the elements of the returned
  49. // SymBandDense will be reflected in data. If neither of these is true, NewSymBandDense
  50. // will panic. k must be at least zero and less than n, otherwise NewSymBandDense will panic.
  51. //
  52. // The data must be arranged in row-major order constructed by removing the zeros
  53. // from the rows outside the band and aligning the diagonals. SymBandDense matrices
  54. // are stored in the upper triangle. For example, the matrix
  55. // 1 2 3 0 0 0
  56. // 2 4 5 6 0 0
  57. // 3 5 7 8 9 0
  58. // 0 6 8 10 11 12
  59. // 0 0 9 11 13 14
  60. // 0 0 0 12 14 15
  61. // becomes (* entries are never accessed)
  62. // 1 2 3
  63. // 4 5 6
  64. // 7 8 9
  65. // 10 11 12
  66. // 13 14 *
  67. // 15 * *
  68. // which is passed to NewSymBandDense as []float64{1, 2, ..., 15, *, *, *} with k=2.
  69. // Only the values in the band portion of the matrix are used.
  70. func NewSymBandDense(n, k int, data []float64) *SymBandDense {
  71. if n <= 0 || k < 0 {
  72. if n == 0 {
  73. panic(ErrZeroLength)
  74. }
  75. panic("mat: negative dimension")
  76. }
  77. if k+1 > n {
  78. panic("mat: band out of range")
  79. }
  80. bc := k + 1
  81. if data != nil && len(data) != n*bc {
  82. panic(ErrShape)
  83. }
  84. if data == nil {
  85. data = make([]float64, n*bc)
  86. }
  87. return &SymBandDense{
  88. mat: blas64.SymmetricBand{
  89. N: n,
  90. K: k,
  91. Stride: bc,
  92. Uplo: blas.Upper,
  93. Data: data,
  94. },
  95. }
  96. }
  97. // Dims returns the number of rows and columns in the matrix.
  98. func (s *SymBandDense) Dims() (r, c int) {
  99. return s.mat.N, s.mat.N
  100. }
  101. // Symmetric returns the size of the receiver.
  102. func (s *SymBandDense) Symmetric() int {
  103. return s.mat.N
  104. }
  105. // Bandwidth returns the bandwidths of the matrix.
  106. func (s *SymBandDense) Bandwidth() (kl, ku int) {
  107. return s.mat.K, s.mat.K
  108. }
  109. // SymBand returns the number of rows/columns in the matrix, and the size of
  110. // the bandwidth.
  111. func (s *SymBandDense) SymBand() (n, k int) {
  112. return s.mat.N, s.mat.K
  113. }
  114. // T implements the Matrix interface. Symmetric matrices, by definition, are
  115. // equal to their transpose, and this is a no-op.
  116. func (s *SymBandDense) T() Matrix {
  117. return s
  118. }
  119. // TBand implements the Banded interface.
  120. func (s *SymBandDense) TBand() Banded {
  121. return s
  122. }
  123. // RawSymBand returns the underlying blas64.SymBand used by the receiver.
  124. // Changes to elements in the receiver following the call will be reflected
  125. // in returned blas64.SymBand.
  126. func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
  127. return s.mat
  128. }
  129. // SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
  130. // Changes to elements in the receiver following the call will be reflected
  131. // in the input.
  132. //
  133. // The supplied SymmetricBand must use blas.Upper storage format.
  134. func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
  135. if mat.Uplo != blas.Upper {
  136. panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
  137. }
  138. s.mat = mat
  139. }
  140. // Zero sets all of the matrix elements to zero.
  141. func (s *SymBandDense) Zero() {
  142. for i := 0; i < s.mat.N; i++ {
  143. u := min(1+s.mat.K, s.mat.N-i)
  144. zero(s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+u])
  145. }
  146. }
  147. // DiagView returns the diagonal as a matrix backed by the original data.
  148. func (s *SymBandDense) DiagView() Diagonal {
  149. n := s.mat.N
  150. return &DiagDense{
  151. mat: blas64.Vector{
  152. N: n,
  153. Inc: s.mat.Stride,
  154. Data: s.mat.Data[:(n-1)*s.mat.Stride+1],
  155. },
  156. }
  157. }
  158. // DoNonZero calls the function fn for each of the non-zero elements of s. The function fn
  159. // takes a row/column index and the element value of s at (i, j).
  160. func (s *SymBandDense) DoNonZero(fn func(i, j int, v float64)) {
  161. for i := 0; i < s.mat.N; i++ {
  162. for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
  163. v := s.at(i, j)
  164. if v != 0 {
  165. fn(i, j, v)
  166. }
  167. }
  168. }
  169. }
  170. // DoRowNonZero calls the function fn for each of the non-zero elements of row i of s. The function fn
  171. // takes a row/column index and the element value of s at (i, j).
  172. func (s *SymBandDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
  173. if i < 0 || s.mat.N <= i {
  174. panic(ErrRowAccess)
  175. }
  176. for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
  177. v := s.at(i, j)
  178. if v != 0 {
  179. fn(i, j, v)
  180. }
  181. }
  182. }
  183. // DoColNonZero calls the function fn for each of the non-zero elements of column j of s. The function fn
  184. // takes a row/column index and the element value of s at (i, j).
  185. func (s *SymBandDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
  186. if j < 0 || s.mat.N <= j {
  187. panic(ErrColAccess)
  188. }
  189. for i := 0; i < s.mat.N; i++ {
  190. if i-s.mat.K <= j && j < i+s.mat.K+1 {
  191. v := s.at(i, j)
  192. if v != 0 {
  193. fn(i, j, v)
  194. }
  195. }
  196. }
  197. }