blob: 55a388553df5d51acd707fa39c61482bc9c3fdf9 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "fmt"
9 "io"
khenaidoo26721882021-08-11 17:42:52 -040010)
11
// byteBuffer is the set of read operations an input source must provide.
type byteBuffer interface {
	// readSmall reads n bytes, where n is at most 8.
	// Returns io.ErrUnexpectedEOF if this cannot be satisfied.
	readSmall(n int) ([]byte, error)

	// readBig reads n bytes and is meant for reads of more than 8 bytes.
	// The implementation MAY use the destination slice for the result.
	readBig(n int, dst []byte) ([]byte, error)

	// readByte reads a single byte.
	readByte() (byte, error)

	// skipN skips n bytes.
	skipN(n int64) error
}
27
// byteBuf is an in-memory buffer: a byte slice whose methods return data from
// the front and re-slice the buffer past whatever was consumed.
type byteBuf []byte
30
Abhay Kumar40252eb2025-10-13 13:25:53 +000031func (b *byteBuf) readSmall(n int) ([]byte, error) {
khenaidoo26721882021-08-11 17:42:52 -040032 if debugAsserts && n > 8 {
33 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
34 }
35 bb := *b
36 if len(bb) < n {
Abhay Kumar40252eb2025-10-13 13:25:53 +000037 return nil, io.ErrUnexpectedEOF
khenaidoo26721882021-08-11 17:42:52 -040038 }
39 r := bb[:n]
40 *b = bb[n:]
Abhay Kumar40252eb2025-10-13 13:25:53 +000041 return r, nil
khenaidoo26721882021-08-11 17:42:52 -040042}
43
44func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
45 bb := *b
46 if len(bb) < n {
47 return nil, io.ErrUnexpectedEOF
48 }
49 r := bb[:n]
50 *b = bb[n:]
51 return r, nil
52}
53
khenaidoo26721882021-08-11 17:42:52 -040054func (b *byteBuf) readByte() (byte, error) {
55 bb := *b
56 if len(bb) < 1 {
Abhay Kumar40252eb2025-10-13 13:25:53 +000057 return 0, io.ErrUnexpectedEOF
khenaidoo26721882021-08-11 17:42:52 -040058 }
59 r := bb[0]
60 *b = bb[1:]
61 return r, nil
62}
63
Abhay Kumar40252eb2025-10-13 13:25:53 +000064func (b *byteBuf) skipN(n int64) error {
khenaidoo26721882021-08-11 17:42:52 -040065 bb := *b
Abhay Kumar40252eb2025-10-13 13:25:53 +000066 if n < 0 {
67 return fmt.Errorf("negative skip (%d) requested", n)
68 }
69 if int64(len(bb)) < n {
khenaidoo26721882021-08-11 17:42:52 -040070 return io.ErrUnexpectedEOF
71 }
72 *b = bb[n:]
73 return nil
74}
75
// readerWrapper is a wrapper around an io.Reader.
type readerWrapper struct {
	r   io.Reader
	tmp [8]byte // scratch space that small reads are read into
}
81
Abhay Kumar40252eb2025-10-13 13:25:53 +000082func (r *readerWrapper) readSmall(n int) ([]byte, error) {
khenaidoo26721882021-08-11 17:42:52 -040083 if debugAsserts && n > 8 {
84 panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
85 }
86 n2, err := io.ReadFull(r.r, r.tmp[:n])
87 // We only really care about the actual bytes read.
Abhay Kumar40252eb2025-10-13 13:25:53 +000088 if err != nil {
89 if err == io.EOF {
90 return nil, io.ErrUnexpectedEOF
91 }
92 if debugDecoder {
khenaidoo26721882021-08-11 17:42:52 -040093 println("readSmall: got", n2, "want", n, "err", err)
94 }
Abhay Kumar40252eb2025-10-13 13:25:53 +000095 return nil, err
khenaidoo26721882021-08-11 17:42:52 -040096 }
Abhay Kumar40252eb2025-10-13 13:25:53 +000097 return r.tmp[:n], nil
khenaidoo26721882021-08-11 17:42:52 -040098}
99
100func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
101 if cap(dst) < n {
102 dst = make([]byte, n)
103 }
104 n2, err := io.ReadFull(r.r, dst[:n])
105 if err == io.EOF && n > 0 {
106 err = io.ErrUnexpectedEOF
107 }
108 return dst[:n2], err
109}
110
111func (r *readerWrapper) readByte() (byte, error) {
Abhay Kumar40252eb2025-10-13 13:25:53 +0000112 n2, err := io.ReadFull(r.r, r.tmp[:1])
khenaidoo26721882021-08-11 17:42:52 -0400113 if err != nil {
Abhay Kumar40252eb2025-10-13 13:25:53 +0000114 if err == io.EOF {
115 err = io.ErrUnexpectedEOF
116 }
khenaidoo26721882021-08-11 17:42:52 -0400117 return 0, err
118 }
119 if n2 != 1 {
120 return 0, io.ErrUnexpectedEOF
121 }
122 return r.tmp[0], nil
123}
124
Abhay Kumar40252eb2025-10-13 13:25:53 +0000125func (r *readerWrapper) skipN(n int64) error {
126 n2, err := io.CopyN(io.Discard, r.r, n)
127 if n2 != n {
khenaidoo26721882021-08-11 17:42:52 -0400128 err = io.ErrUnexpectedEOF
129 }
130 return err
131}