io.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "bytes"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math"
  12. )
  13. // version is the current on-disk codec version.
  14. const version uint32 = 0x1
  15. // maxLen is the biggest slice/array len one can create on a 32/64b platform.
  16. const maxLen = int64(int(^uint(0) >> 1))
  17. var (
  18. headerSize = binary.Size(storage{})
  19. sizeInt64 = binary.Size(int64(0))
  20. sizeFloat64 = binary.Size(float64(0))
  21. errWrongType = errors.New("mat: wrong data type")
  22. errTooBig = errors.New("mat: resulting data slice too big")
  23. errTooSmall = errors.New("mat: input slice too small")
  24. errBadBuffer = errors.New("mat: data buffer size mismatch")
  25. errBadSize = errors.New("mat: invalid dimension")
  26. )
// Type encoding scheme:
//
//	Type              Form         Packing      Uplo         Unit        Rows   Columns  kU     kL
//	                  uint8 [GST]  uint8 [BPF]  uint8 [AUL]  bool        int64  int64    int64  int64
//	General           'G'          'F'          'A'          false       r      c        0      0
//	Band              'G'          'B'          'A'          false       r      c        kU     kL
//	Symmetric         'S'          'F'          ul           false       n      n        0      0
//	SymmetricBand     'S'          'B'          ul           false       n      n        k      k
//	SymmetricPacked   'S'          'P'          ul           false       n      n        0      0
//	Triangular        'T'          'F'          ul           Diag==Unit  n      n        0      0
//	TriangularBand    'T'          'B'          ul           Diag==Unit  n      n        k      k
//	TriangularPacked  'T'          'P'          ul           Diag==Unit  n      n        0      0
//
// G - general, S - symmetric, T - triangular
// F - full, B - band, P - packed
// A - all, U - upper, L - lower
  43. // MarshalBinary encodes the receiver into a binary form and returns the result.
  44. //
  45. // Dense is little-endian encoded as follows:
  46. // 0 - 3 Version = 1 (uint32)
  47. // 4 'G' (byte)
  48. // 5 'F' (byte)
  49. // 6 'A' (byte)
  50. // 7 0 (byte)
  51. // 8 - 15 number of rows (int64)
  52. // 16 - 23 number of columns (int64)
  53. // 24 - 31 0 (int64)
  54. // 32 - 39 0 (int64)
  55. // 40 - .. matrix data elements (float64)
  56. // [0,0] [0,1] ... [0,ncols-1]
  57. // [1,0] [1,1] ... [1,ncols-1]
  58. // ...
  59. // [nrows-1,0] ... [nrows-1,ncols-1]
  60. func (m Dense) MarshalBinary() ([]byte, error) {
  61. bufLen := int64(headerSize) + int64(m.mat.Rows)*int64(m.mat.Cols)*int64(sizeFloat64)
  62. if bufLen <= 0 {
  63. // bufLen is too big and has wrapped around.
  64. return nil, errTooBig
  65. }
  66. header := storage{
  67. Form: 'G', Packing: 'F', Uplo: 'A',
  68. Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
  69. Version: version,
  70. }
  71. buf := make([]byte, bufLen)
  72. n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
  73. if err != nil {
  74. return buf[:n], err
  75. }
  76. p := headerSize
  77. r, c := m.Dims()
  78. for i := 0; i < r; i++ {
  79. for j := 0; j < c; j++ {
  80. binary.LittleEndian.PutUint64(buf[p:p+sizeFloat64], math.Float64bits(m.at(i, j)))
  81. p += sizeFloat64
  82. }
  83. }
  84. return buf, nil
  85. }
  86. // MarshalBinaryTo encodes the receiver into a binary form and writes it into w.
  87. // MarshalBinaryTo returns the number of bytes written into w and an error, if any.
  88. //
  89. // See MarshalBinary for the on-disk layout.
  90. func (m Dense) MarshalBinaryTo(w io.Writer) (int, error) {
  91. header := storage{
  92. Form: 'G', Packing: 'F', Uplo: 'A',
  93. Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
  94. Version: version,
  95. }
  96. n, err := header.marshalBinaryTo(w)
  97. if err != nil {
  98. return n, err
  99. }
  100. r, c := m.Dims()
  101. var b [8]byte
  102. for i := 0; i < r; i++ {
  103. for j := 0; j < c; j++ {
  104. binary.LittleEndian.PutUint64(b[:], math.Float64bits(m.at(i, j)))
  105. nn, err := w.Write(b[:])
  106. n += nn
  107. if err != nil {
  108. return n, err
  109. }
  110. }
  111. }
  112. return n, nil
  113. }
  114. // UnmarshalBinary decodes the binary form into the receiver.
  115. // It panics if the receiver is a non-zero Dense matrix.
  116. //
  117. // See MarshalBinary for the on-disk layout.
  118. //
  119. // Limited checks on the validity of the binary input are performed:
  120. // - matrix.ErrShape is returned if the number of rows or columns is negative,
  121. // - an error is returned if the resulting Dense matrix is too
  122. // big for the current architecture (e.g. a 16GB matrix written by a
  123. // 64b application and read back from a 32b application.)
  124. // UnmarshalBinary does not limit the size of the unmarshaled matrix, and so
  125. // it should not be used on untrusted data.
  126. func (m *Dense) UnmarshalBinary(data []byte) error {
  127. if !m.IsZero() {
  128. panic("mat: unmarshal into non-zero matrix")
  129. }
  130. if len(data) < headerSize {
  131. return errTooSmall
  132. }
  133. var header storage
  134. err := header.unmarshalBinary(data[:headerSize])
  135. if err != nil {
  136. return err
  137. }
  138. rows := header.Rows
  139. cols := header.Cols
  140. header.Version = 0
  141. header.Rows = 0
  142. header.Cols = 0
  143. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  144. return errWrongType
  145. }
  146. if rows < 0 || cols < 0 {
  147. return errBadSize
  148. }
  149. size := rows * cols
  150. if size == 0 {
  151. return ErrZeroLength
  152. }
  153. if int(size) < 0 || size > maxLen {
  154. return errTooBig
  155. }
  156. if len(data) != headerSize+int(rows*cols)*sizeFloat64 {
  157. return errBadBuffer
  158. }
  159. p := headerSize
  160. m.reuseAs(int(rows), int(cols))
  161. for i := range m.mat.Data {
  162. m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
  163. p += sizeFloat64
  164. }
  165. return nil
  166. }
  167. // UnmarshalBinaryFrom decodes the binary form into the receiver and returns
  168. // the number of bytes read and an error if any.
  169. // It panics if the receiver is a non-zero Dense matrix.
  170. //
  171. // See MarshalBinary for the on-disk layout.
  172. //
  173. // Limited checks on the validity of the binary input are performed:
  174. // - matrix.ErrShape is returned if the number of rows or columns is negative,
  175. // - an error is returned if the resulting Dense matrix is too
  176. // big for the current architecture (e.g. a 16GB matrix written by a
  177. // 64b application and read back from a 32b application.)
  178. // UnmarshalBinary does not limit the size of the unmarshaled matrix, and so
  179. // it should not be used on untrusted data.
  180. func (m *Dense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
  181. if !m.IsZero() {
  182. panic("mat: unmarshal into non-zero matrix")
  183. }
  184. var header storage
  185. n, err := header.unmarshalBinaryFrom(r)
  186. if err != nil {
  187. return n, err
  188. }
  189. rows := header.Rows
  190. cols := header.Cols
  191. header.Version = 0
  192. header.Rows = 0
  193. header.Cols = 0
  194. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  195. return n, errWrongType
  196. }
  197. if rows < 0 || cols < 0 {
  198. return n, errBadSize
  199. }
  200. size := rows * cols
  201. if size == 0 {
  202. return n, ErrZeroLength
  203. }
  204. if int(size) < 0 || size > maxLen {
  205. return n, errTooBig
  206. }
  207. m.reuseAs(int(rows), int(cols))
  208. var b [8]byte
  209. for i := range m.mat.Data {
  210. nn, err := readFull(r, b[:])
  211. n += nn
  212. if err != nil {
  213. if err == io.EOF {
  214. return n, io.ErrUnexpectedEOF
  215. }
  216. return n, err
  217. }
  218. m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
  219. }
  220. return n, nil
  221. }
  222. // MarshalBinary encodes the receiver into a binary form and returns the result.
  223. //
  224. // VecDense is little-endian encoded as follows:
  225. //
  226. // 0 - 3 Version = 1 (uint32)
  227. // 4 'G' (byte)
  228. // 5 'F' (byte)
  229. // 6 'A' (byte)
  230. // 7 0 (byte)
  231. // 8 - 15 number of elements (int64)
  232. // 16 - 23 1 (int64)
  233. // 24 - 31 0 (int64)
  234. // 32 - 39 0 (int64)
  235. // 40 - .. vector's data elements (float64)
  236. func (v VecDense) MarshalBinary() ([]byte, error) {
  237. bufLen := int64(headerSize) + int64(v.mat.N)*int64(sizeFloat64)
  238. if bufLen <= 0 {
  239. // bufLen is too big and has wrapped around.
  240. return nil, errTooBig
  241. }
  242. header := storage{
  243. Form: 'G', Packing: 'F', Uplo: 'A',
  244. Rows: int64(v.mat.N), Cols: 1,
  245. Version: version,
  246. }
  247. buf := make([]byte, bufLen)
  248. n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
  249. if err != nil {
  250. return buf[:n], err
  251. }
  252. p := headerSize
  253. for i := 0; i < v.mat.N; i++ {
  254. binary.LittleEndian.PutUint64(buf[p:p+sizeFloat64], math.Float64bits(v.at(i)))
  255. p += sizeFloat64
  256. }
  257. return buf, nil
  258. }
  259. // MarshalBinaryTo encodes the receiver into a binary form, writes it to w and
  260. // returns the number of bytes written and an error if any.
  261. //
  262. // See MarshalBainry for the on-disk format.
  263. func (v VecDense) MarshalBinaryTo(w io.Writer) (int, error) {
  264. header := storage{
  265. Form: 'G', Packing: 'F', Uplo: 'A',
  266. Rows: int64(v.mat.N), Cols: 1,
  267. Version: version,
  268. }
  269. n, err := header.marshalBinaryTo(w)
  270. if err != nil {
  271. return n, err
  272. }
  273. var buf [8]byte
  274. for i := 0; i < v.mat.N; i++ {
  275. binary.LittleEndian.PutUint64(buf[:], math.Float64bits(v.at(i)))
  276. nn, err := w.Write(buf[:])
  277. n += nn
  278. if err != nil {
  279. return n, err
  280. }
  281. }
  282. return n, nil
  283. }
  284. // UnmarshalBinary decodes the binary form into the receiver.
  285. // It panics if the receiver is a non-zero VecDense.
  286. //
  287. // See MarshalBinary for the on-disk layout.
  288. //
  289. // Limited checks on the validity of the binary input are performed:
  290. // - matrix.ErrShape is returned if the number of rows is negative,
  291. // - an error is returned if the resulting VecDense is too
  292. // big for the current architecture (e.g. a 16GB vector written by a
  293. // 64b application and read back from a 32b application.)
  294. // UnmarshalBinary does not limit the size of the unmarshaled vector, and so
  295. // it should not be used on untrusted data.
  296. func (v *VecDense) UnmarshalBinary(data []byte) error {
  297. if !v.IsZero() {
  298. panic("mat: unmarshal into non-zero vector")
  299. }
  300. if len(data) < headerSize {
  301. return errTooSmall
  302. }
  303. var header storage
  304. err := header.unmarshalBinary(data[:headerSize])
  305. if err != nil {
  306. return err
  307. }
  308. if header.Cols != 1 {
  309. return ErrShape
  310. }
  311. n := header.Rows
  312. header.Version = 0
  313. header.Rows = 0
  314. header.Cols = 0
  315. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  316. return errWrongType
  317. }
  318. if n == 0 {
  319. return ErrZeroLength
  320. }
  321. if n < 0 {
  322. return errBadSize
  323. }
  324. if int64(maxLen) < n {
  325. return errTooBig
  326. }
  327. if len(data) != headerSize+int(n)*sizeFloat64 {
  328. return errBadBuffer
  329. }
  330. p := headerSize
  331. v.reuseAs(int(n))
  332. for i := range v.mat.Data {
  333. v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
  334. p += sizeFloat64
  335. }
  336. return nil
  337. }
  338. // UnmarshalBinaryFrom decodes the binary form into the receiver, from the
  339. // io.Reader and returns the number of bytes read and an error if any.
  340. // It panics if the receiver is a non-zero VecDense.
  341. //
  342. // See MarshalBinary for the on-disk layout.
  343. // See UnmarshalBinary for the list of sanity checks performed on the input.
  344. func (v *VecDense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
  345. if !v.IsZero() {
  346. panic("mat: unmarshal into non-zero vector")
  347. }
  348. var header storage
  349. n, err := header.unmarshalBinaryFrom(r)
  350. if err != nil {
  351. return n, err
  352. }
  353. if header.Cols != 1 {
  354. return n, ErrShape
  355. }
  356. l := header.Rows
  357. header.Version = 0
  358. header.Rows = 0
  359. header.Cols = 0
  360. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  361. return n, errWrongType
  362. }
  363. if l == 0 {
  364. return n, ErrZeroLength
  365. }
  366. if l < 0 {
  367. return n, errBadSize
  368. }
  369. if int64(maxLen) < l {
  370. return n, errTooBig
  371. }
  372. v.reuseAs(int(l))
  373. var b [8]byte
  374. for i := range v.mat.Data {
  375. nn, err := readFull(r, b[:])
  376. n += nn
  377. if err != nil {
  378. if err == io.EOF {
  379. return n, io.ErrUnexpectedEOF
  380. }
  381. return n, err
  382. }
  383. v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
  384. }
  385. return n, nil
  386. }
  387. // storage is the internal representation of the storage format of a
  388. // serialised matrix.
  389. type storage struct {
  390. Version uint32 // Keep this first.
  391. Form byte // [GST]
  392. Packing byte // [BPF]
  393. Uplo byte // [AUL]
  394. Unit bool
  395. Rows int64
  396. Cols int64
  397. KU int64
  398. KL int64
  399. }
  400. // TODO(kortschak): Consider replacing these with calls to direct
  401. // encoding/decoding of fields rather than to binary.Write/binary.Read.
  402. func (s storage) marshalBinaryTo(w io.Writer) (int, error) {
  403. buf := bytes.NewBuffer(make([]byte, 0, headerSize))
  404. err := binary.Write(buf, binary.LittleEndian, s)
  405. if err != nil {
  406. return 0, err
  407. }
  408. return w.Write(buf.Bytes())
  409. }
  410. func (s *storage) unmarshalBinary(buf []byte) error {
  411. err := binary.Read(bytes.NewReader(buf), binary.LittleEndian, s)
  412. if err != nil {
  413. return err
  414. }
  415. if s.Version != version {
  416. return fmt.Errorf("mat: incorrect version: %d", s.Version)
  417. }
  418. return nil
  419. }
  420. func (s *storage) unmarshalBinaryFrom(r io.Reader) (int, error) {
  421. buf := make([]byte, headerSize)
  422. n, err := readFull(r, buf)
  423. if err != nil {
  424. return n, err
  425. }
  426. return n, s.unmarshalBinary(buf[:n])
  427. }
  428. // readFull reads from r into buf until it has read len(buf).
  429. // It returns the number of bytes copied and an error if fewer bytes were read.
  430. // If an EOF happens after reading fewer than len(buf) bytes, io.ErrUnexpectedEOF is returned.
  431. func readFull(r io.Reader, buf []byte) (int, error) {
  432. var n int
  433. var err error
  434. for n < len(buf) && err == nil {
  435. var nn int
  436. nn, err = r.Read(buf[n:])
  437. n += nn
  438. }
  439. if n == len(buf) {
  440. return n, nil
  441. }
  442. if err == io.EOF {
  443. return n, io.ErrUnexpectedEOF
  444. }
  445. return n, err
  446. }