diagonal.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Copyright ©2018 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas"
  7. "gonum.org/v1/gonum/blas/blas64"
  8. )
  9. var (
  10. diagDense *DiagDense
  11. _ Matrix = diagDense
  12. _ Diagonal = diagDense
  13. _ MutableDiagonal = diagDense
  14. _ Triangular = diagDense
  15. _ TriBanded = diagDense
  16. _ Symmetric = diagDense
  17. _ SymBanded = diagDense
  18. _ Banded = diagDense
  19. _ RawBander = diagDense
  20. _ RawSymBander = diagDense
  21. diag Diagonal
  22. _ Matrix = diag
  23. _ Diagonal = diag
  24. _ Triangular = diag
  25. _ TriBanded = diag
  26. _ Symmetric = diag
  27. _ SymBanded = diag
  28. _ Banded = diag
  29. )
  30. // Diagonal represents a diagonal matrix, that is a square matrix that only
  31. // has non-zero terms on the diagonal.
  32. type Diagonal interface {
  33. Matrix
  34. // Diag returns the number of rows/columns in the matrix.
  35. Diag() int
  36. // Bandwidth and TBand are included in the Diagonal interface
  37. // to allow the use of Diagonal types in banded functions.
  38. // Bandwidth will always return (0, 0).
  39. Bandwidth() (kl, ku int)
  40. TBand() Banded
  41. // Triangle and TTri are included in the Diagonal interface
  42. // to allow the use of Diagonal types in triangular functions.
  43. Triangle() (int, TriKind)
  44. TTri() Triangular
  45. // Symmetric and SymBand are included in the Diagonal interface
  46. // to allow the use of Diagonal types in symmetric and banded symmetric
  47. // functions respectively.
  48. Symmetric() int
  49. SymBand() (n, k int)
  50. // TriBand and TTriBand are included in the Diagonal interface
  51. // to allow the use of Diagonal types in triangular banded functions.
  52. TriBand() (n, k int, kind TriKind)
  53. TTriBand() TriBanded
  54. }
  55. // MutableDiagonal is a Diagonal matrix whose elements can be set.
  56. type MutableDiagonal interface {
  57. Diagonal
  58. SetDiag(i int, v float64)
  59. }
  60. // DiagDense represents a diagonal matrix in dense storage format.
  61. type DiagDense struct {
  62. mat blas64.Vector
  63. }
  64. // NewDiagDense creates a new Diagonal matrix with n rows and n columns.
  65. // The length of data must be n or data must be nil, otherwise NewDiagDense
  66. // will panic. NewDiagDense will panic if n is zero.
  67. func NewDiagDense(n int, data []float64) *DiagDense {
  68. if n <= 0 {
  69. if n == 0 {
  70. panic(ErrZeroLength)
  71. }
  72. panic("mat: negative dimension")
  73. }
  74. if data == nil {
  75. data = make([]float64, n)
  76. }
  77. if len(data) != n {
  78. panic(ErrShape)
  79. }
  80. return &DiagDense{
  81. mat: blas64.Vector{N: n, Data: data, Inc: 1},
  82. }
  83. }
  84. // Diag returns the dimension of the receiver.
  85. func (d *DiagDense) Diag() int {
  86. return d.mat.N
  87. }
  88. // Dims returns the dimensions of the matrix.
  89. func (d *DiagDense) Dims() (r, c int) {
  90. return d.mat.N, d.mat.N
  91. }
  92. // T returns the transpose of the matrix.
  93. func (d *DiagDense) T() Matrix {
  94. return d
  95. }
  96. // TTri returns the transpose of the matrix. Note that Diagonal matrices are
  97. // Upper by default.
  98. func (d *DiagDense) TTri() Triangular {
  99. return TransposeTri{d}
  100. }
  101. // TBand performs an implicit transpose by returning the receiver inside a
  102. // TransposeBand.
  103. func (d *DiagDense) TBand() Banded {
  104. return TransposeBand{d}
  105. }
  106. // TTriBand performs an implicit transpose by returning the receiver inside a
  107. // TransposeTriBand. Note that Diagonal matrices are Upper by default.
  108. func (d *DiagDense) TTriBand() TriBanded {
  109. return TransposeTriBand{d}
  110. }
  111. // Bandwidth returns the upper and lower bandwidths of the matrix.
  112. // These values are always zero for diagonal matrices.
  113. func (d *DiagDense) Bandwidth() (kl, ku int) {
  114. return 0, 0
  115. }
  116. // Symmetric implements the Symmetric interface.
  117. func (d *DiagDense) Symmetric() int {
  118. return d.mat.N
  119. }
  120. // SymBand returns the number of rows/columns in the matrix, and the size of
  121. // the bandwidth.
  122. func (d *DiagDense) SymBand() (n, k int) {
  123. return d.mat.N, 0
  124. }
  125. // Triangle implements the Triangular interface.
  126. func (d *DiagDense) Triangle() (int, TriKind) {
  127. return d.mat.N, Upper
  128. }
  129. // TriBand returns the number of rows/columns in the matrix, the
  130. // size of the bandwidth, and the orientation. Note that Diagonal matrices are
  131. // Upper by default.
  132. func (d *DiagDense) TriBand() (n, k int, kind TriKind) {
  133. return d.mat.N, 0, Upper
  134. }
  135. // Reset zeros the length of the matrix so that it can be reused as the
  136. // receiver of a dimensionally restricted operation.
  137. //
  138. // See the Reseter interface for more information.
  139. func (d *DiagDense) Reset() {
  140. // No change of Inc or n to 0 may be
  141. // made unless both are set to 0.
  142. d.mat.Inc = 0
  143. d.mat.N = 0
  144. d.mat.Data = d.mat.Data[:0]
  145. }
  146. // Zero sets all of the matrix elements to zero.
  147. func (d *DiagDense) Zero() {
  148. for i := 0; i < d.mat.N; i++ {
  149. d.mat.Data[d.mat.Inc*i] = 0
  150. }
  151. }
  152. // DiagView returns the diagonal as a matrix backed by the original data.
  153. func (d *DiagDense) DiagView() Diagonal {
  154. return d
  155. }
  156. // DiagFrom copies the diagonal of m into the receiver. The receiver must
  157. // be min(r, c) long or zero. Otherwise DiagFrom will panic.
  158. func (d *DiagDense) DiagFrom(m Matrix) {
  159. n := min(m.Dims())
  160. d.reuseAs(n)
  161. var vec blas64.Vector
  162. switch r := m.(type) {
  163. case *DiagDense:
  164. vec = r.mat
  165. case RawBander:
  166. mat := r.RawBand()
  167. vec = blas64.Vector{
  168. N: n,
  169. Inc: mat.Stride,
  170. Data: mat.Data[mat.KL : (n-1)*mat.Stride+mat.KL+1],
  171. }
  172. case RawMatrixer:
  173. mat := r.RawMatrix()
  174. vec = blas64.Vector{
  175. N: n,
  176. Inc: mat.Stride + 1,
  177. Data: mat.Data[:(n-1)*mat.Stride+n],
  178. }
  179. case RawSymBander:
  180. mat := r.RawSymBand()
  181. vec = blas64.Vector{
  182. N: n,
  183. Inc: mat.Stride,
  184. Data: mat.Data[:(n-1)*mat.Stride+1],
  185. }
  186. case RawSymmetricer:
  187. mat := r.RawSymmetric()
  188. vec = blas64.Vector{
  189. N: n,
  190. Inc: mat.Stride + 1,
  191. Data: mat.Data[:(n-1)*mat.Stride+n],
  192. }
  193. case RawTriBander:
  194. mat := r.RawTriBand()
  195. data := mat.Data
  196. if mat.Uplo == blas.Lower {
  197. data = data[mat.K:]
  198. }
  199. vec = blas64.Vector{
  200. N: n,
  201. Inc: mat.Stride,
  202. Data: data[:(n-1)*mat.Stride+1],
  203. }
  204. case RawTriangular:
  205. mat := r.RawTriangular()
  206. if mat.Diag == blas.Unit {
  207. for i := 0; i < n; i += d.mat.Inc {
  208. d.mat.Data[i] = 1
  209. }
  210. return
  211. }
  212. vec = blas64.Vector{
  213. N: n,
  214. Inc: mat.Stride + 1,
  215. Data: mat.Data[:(n-1)*mat.Stride+n],
  216. }
  217. case RawVectorer:
  218. d.mat.Data[0] = r.RawVector().Data[0]
  219. return
  220. default:
  221. for i := 0; i < n; i++ {
  222. d.setDiag(i, m.At(i, i))
  223. }
  224. return
  225. }
  226. blas64.Copy(vec, d.mat)
  227. }
  228. // RawBand returns the underlying data used by the receiver represented
  229. // as a blas64.Band.
  230. // Changes to elements in the receiver following the call will be reflected
  231. // in returned blas64.Band.
  232. func (d *DiagDense) RawBand() blas64.Band {
  233. return blas64.Band{
  234. Rows: d.mat.N,
  235. Cols: d.mat.N,
  236. KL: 0,
  237. KU: 0,
  238. Stride: d.mat.Inc,
  239. Data: d.mat.Data,
  240. }
  241. }
  242. // RawSymBand returns the underlying data used by the receiver represented
  243. // as a blas64.SymmetricBand.
  244. // Changes to elements in the receiver following the call will be reflected
  245. // in returned blas64.Band.
  246. func (d *DiagDense) RawSymBand() blas64.SymmetricBand {
  247. return blas64.SymmetricBand{
  248. N: d.mat.N,
  249. K: 0,
  250. Stride: d.mat.Inc,
  251. Uplo: blas.Upper,
  252. Data: d.mat.Data,
  253. }
  254. }
  255. // reuseAs resizes an empty diagonal to a r×r diagonal,
  256. // or checks that a non-empty matrix is r×r.
  257. func (d *DiagDense) reuseAs(r int) {
  258. if r == 0 {
  259. panic(ErrZeroLength)
  260. }
  261. if d.IsZero() {
  262. d.mat = blas64.Vector{
  263. Inc: 1,
  264. Data: use(d.mat.Data, r),
  265. }
  266. d.mat.N = r
  267. return
  268. }
  269. if r != d.mat.N {
  270. panic(ErrShape)
  271. }
  272. }
  273. // IsZero returns whether the receiver is zero-sized. Zero-sized vectors can be the
  274. // receiver for size-restricted operations. DiagDenses can be zeroed using Reset.
  275. func (d *DiagDense) IsZero() bool {
  276. // It must be the case that d.Dims() returns
  277. // zeros in this case. See comment in Reset().
  278. return d.mat.Inc == 0
  279. }