shadow.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "gonum.org/v1/gonum/blas/blas64"
  7. )
  8. const (
  9. // regionOverlap is the panic string used for the general case
  10. // of a matrix region overlap between a source and destination.
  11. regionOverlap = "mat: bad region: overlap"
  12. // regionIdentity is the panic string used for the specific
  13. // case of complete agreement between a source and a destination.
  14. regionIdentity = "mat: bad region: identical"
  15. // mismatchedStrides is the panic string used for overlapping
  16. // data slices with differing strides.
  17. mismatchedStrides = "mat: bad region: different strides"
  18. )
  19. // checkOverlap returns false if the receiver does not overlap data elements
  20. // referenced by the parameter and panics otherwise.
  21. //
  22. // checkOverlap methods return a boolean to allow the check call to be added to a
  23. // boolean expression, making use of short-circuit operators.
  24. func checkOverlap(a, b blas64.General) bool {
  25. if cap(a.Data) == 0 || cap(b.Data) == 0 {
  26. return false
  27. }
  28. off := offset(a.Data[:1], b.Data[:1])
  29. if off == 0 {
  30. // At least one element overlaps.
  31. if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride {
  32. panic(regionIdentity)
  33. }
  34. panic(regionOverlap)
  35. }
  36. if off > 0 && len(a.Data) <= off {
  37. // We know a is completely before b.
  38. return false
  39. }
  40. if off < 0 && len(b.Data) <= -off {
  41. // We know a is completely after b.
  42. return false
  43. }
  44. if a.Stride != b.Stride {
  45. // Too hard, so assume the worst.
  46. panic(mismatchedStrides)
  47. }
  48. if off < 0 {
  49. off = -off
  50. a.Cols, b.Cols = b.Cols, a.Cols
  51. }
  52. if rectanglesOverlap(off, a.Cols, b.Cols, a.Stride) {
  53. panic(regionOverlap)
  54. }
  55. return false
  56. }
  57. func (m *Dense) checkOverlap(a blas64.General) bool {
  58. return checkOverlap(m.RawMatrix(), a)
  59. }
  60. func (m *Dense) checkOverlapMatrix(a Matrix) bool {
  61. if m == a {
  62. return false
  63. }
  64. var amat blas64.General
  65. switch a := a.(type) {
  66. default:
  67. return false
  68. case RawMatrixer:
  69. amat = a.RawMatrix()
  70. case RawSymmetricer:
  71. amat = generalFromSymmetric(a.RawSymmetric())
  72. case RawTriangular:
  73. amat = generalFromTriangular(a.RawTriangular())
  74. }
  75. return m.checkOverlap(amat)
  76. }
  77. func (s *SymDense) checkOverlap(a blas64.General) bool {
  78. return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
  79. }
  80. func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
  81. if s == a {
  82. return false
  83. }
  84. var amat blas64.General
  85. switch a := a.(type) {
  86. default:
  87. return false
  88. case RawMatrixer:
  89. amat = a.RawMatrix()
  90. case RawSymmetricer:
  91. amat = generalFromSymmetric(a.RawSymmetric())
  92. case RawTriangular:
  93. amat = generalFromTriangular(a.RawTriangular())
  94. }
  95. return s.checkOverlap(amat)
  96. }
  97. // generalFromSymmetric returns a blas64.General with the backing
  98. // data and dimensions of a.
  99. func generalFromSymmetric(a blas64.Symmetric) blas64.General {
  100. return blas64.General{
  101. Rows: a.N,
  102. Cols: a.N,
  103. Stride: a.Stride,
  104. Data: a.Data,
  105. }
  106. }
  107. func (t *TriDense) checkOverlap(a blas64.General) bool {
  108. return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
  109. }
  110. func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
  111. if t == a {
  112. return false
  113. }
  114. var amat blas64.General
  115. switch a := a.(type) {
  116. default:
  117. return false
  118. case RawMatrixer:
  119. amat = a.RawMatrix()
  120. case RawSymmetricer:
  121. amat = generalFromSymmetric(a.RawSymmetric())
  122. case RawTriangular:
  123. amat = generalFromTriangular(a.RawTriangular())
  124. }
  125. return t.checkOverlap(amat)
  126. }
  127. // generalFromTriangular returns a blas64.General with the backing
  128. // data and dimensions of a.
  129. func generalFromTriangular(a blas64.Triangular) blas64.General {
  130. return blas64.General{
  131. Rows: a.N,
  132. Cols: a.N,
  133. Stride: a.Stride,
  134. Data: a.Data,
  135. }
  136. }
  137. func (v *VecDense) checkOverlap(a blas64.Vector) bool {
  138. mat := v.mat
  139. if cap(mat.Data) == 0 || cap(a.Data) == 0 {
  140. return false
  141. }
  142. off := offset(mat.Data[:1], a.Data[:1])
  143. if off == 0 {
  144. // At least one element overlaps.
  145. if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) {
  146. panic(regionIdentity)
  147. }
  148. panic(regionOverlap)
  149. }
  150. if off > 0 && len(mat.Data) <= off {
  151. // We know v is completely before a.
  152. return false
  153. }
  154. if off < 0 && len(a.Data) <= -off {
  155. // We know v is completely after a.
  156. return false
  157. }
  158. if mat.Inc != a.Inc {
  159. // Too hard, so assume the worst.
  160. panic(mismatchedStrides)
  161. }
  162. if mat.Inc == 1 || off&mat.Inc == 0 {
  163. panic(regionOverlap)
  164. }
  165. return false
  166. }
  167. // rectanglesOverlap returns whether the strided rectangles a and b overlap
  168. // when b is offset by off elements after a but has at least one element before
  169. // the end of a. off must be positive. a and b have aCols and bCols respectively.
  170. //
  171. // rectanglesOverlap works by shifting both matrices left such that the left
  172. // column of a is at 0. The column indexes are flattened by obtaining the shifted
  173. // relative left and right column positions modulo the common stride. This allows
  174. // direct comparison of the column offsets when the matrix backing data slices
  175. // are known to overlap.
  176. func rectanglesOverlap(off, aCols, bCols, stride int) bool {
  177. if stride == 1 {
  178. // Unit stride means overlapping data
  179. // slices must overlap as matrices.
  180. return true
  181. }
  182. // Flatten the shifted matrix column positions
  183. // so a starts at 0, modulo the common stride.
  184. aTo := aCols
  185. // The mod stride operations here make the from
  186. // and to indexes comparable between a and b when
  187. // the data slices of a and b overlap.
  188. bFrom := off % stride
  189. bTo := (bFrom + bCols) % stride
  190. if bTo == 0 || bFrom < bTo {
  191. // b matrix is not wrapped: compare for
  192. // simple overlap.
  193. return bFrom < aTo
  194. }
  195. // b strictly wraps and so must overlap with a.
  196. return true
  197. }