blob: d41e3e1709b825b0bc7a8718a531e77ad402d700 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
khenaidood948f772021-08-11 17:49:24 -04008 "errors"
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +05309 "fmt"
khenaidood948f772021-08-11 17:49:24 -040010 "io"
11 "math/bits"
Abhay Kumara2ae5992025-11-10 14:02:24 +000012
13 "github.com/klauspost/compress/internal/le"
khenaidood948f772021-08-11 17:49:24 -040014)
15
// bitReader consumes a bitstream from its last byte towards its first.
// The highest set bit of the final input byte marks where the stream
// begins; that bit and the zero padding above it are skipped by init.
//
// NOTE(review): keep the field order and types unchanged — code elsewhere in
// the package (possibly assembly) may address these fields by fixed offset;
// confirm before reordering.
type bitReader struct {
	// in is the whole input buffer; it is consumed backwards from cursor.
	in []byte
	// value buffers up to 64 bits, consumed from the most significant end.
	// A [16]byte was considered, but shifting it is awkward.
	value uint64
	// cursor is the offset in `in` at which the next load ends.
	cursor int
	// bitsRead is how many bits of value have been consumed; a value above
	// 64 means more bits were requested than the stream held.
	bitsRead uint8
}
25
26// init initializes and resets the bit reader.
27func (b *bitReader) init(in []byte) error {
28 if len(in) < 1 {
29 return errors.New("corrupt stream: too short")
30 }
31 b.in = in
khenaidood948f772021-08-11 17:49:24 -040032 // The highest bit of the last byte indicates where to start
33 v := in[len(in)-1]
34 if v == 0 {
35 return errors.New("corrupt stream, did not find end of stream")
36 }
Abhay Kumara2ae5992025-11-10 14:02:24 +000037 b.cursor = len(in)
khenaidood948f772021-08-11 17:49:24 -040038 b.bitsRead = 64
39 b.value = 0
40 if len(in) >= 8 {
41 b.fillFastStart()
42 } else {
43 b.fill()
44 b.fill()
45 }
46 b.bitsRead += 8 - uint8(highBits(uint32(v)))
47 return nil
48}
49
50// getBits will return n bits. n can be 0.
51func (b *bitReader) getBits(n uint8) int {
52 if n == 0 /*|| b.bitsRead >= 64 */ {
53 return 0
54 }
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053055 return int(b.get32BitsFast(n))
khenaidood948f772021-08-11 17:49:24 -040056}
57
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053058// get32BitsFast requires that at least one bit is requested every time.
khenaidood948f772021-08-11 17:49:24 -040059// There are no checks if the buffer is filled.
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053060func (b *bitReader) get32BitsFast(n uint8) uint32 {
khenaidood948f772021-08-11 17:49:24 -040061 const regMask = 64 - 1
62 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
63 b.bitsRead += n
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053064 return v
khenaidood948f772021-08-11 17:49:24 -040065}
66
67// fillFast() will make sure at least 32 bits are available.
68// There must be at least 4 bytes available.
69func (b *bitReader) fillFast() {
70 if b.bitsRead < 32 {
71 return
72 }
Abhay Kumara2ae5992025-11-10 14:02:24 +000073 b.cursor -= 4
74 b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
khenaidood948f772021-08-11 17:49:24 -040075 b.bitsRead -= 32
khenaidood948f772021-08-11 17:49:24 -040076}
77
78// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
79func (b *bitReader) fillFastStart() {
Abhay Kumara2ae5992025-11-10 14:02:24 +000080 b.cursor -= 8
81 b.value = le.Load64(b.in, b.cursor)
khenaidood948f772021-08-11 17:49:24 -040082 b.bitsRead = 0
khenaidood948f772021-08-11 17:49:24 -040083}
84
85// fill() will make sure at least 32 bits are available.
86func (b *bitReader) fill() {
87 if b.bitsRead < 32 {
88 return
89 }
Abhay Kumara2ae5992025-11-10 14:02:24 +000090 if b.cursor >= 4 {
91 b.cursor -= 4
92 b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
khenaidood948f772021-08-11 17:49:24 -040093 b.bitsRead -= 32
khenaidood948f772021-08-11 17:49:24 -040094 return
95 }
Abhay Kumara2ae5992025-11-10 14:02:24 +000096
97 b.bitsRead -= uint8(8 * b.cursor)
98 for b.cursor > 0 {
99 b.cursor -= 1
100 b.value = (b.value << 8) | uint64(b.in[b.cursor])
khenaidood948f772021-08-11 17:49:24 -0400101 }
102}
103
104// finished returns true if all bits have been read from the bit stream.
105func (b *bitReader) finished() bool {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000106 return b.cursor == 0 && b.bitsRead >= 64
khenaidood948f772021-08-11 17:49:24 -0400107}
108
109// overread returns true if more bits have been requested than is on the stream.
110func (b *bitReader) overread() bool {
111 return b.bitsRead > 64
112}
113
114// remain returns the number of bits remaining.
115func (b *bitReader) remain() uint {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000116 return 8*uint(b.cursor) + 64 - uint(b.bitsRead)
khenaidood948f772021-08-11 17:49:24 -0400117}
118
119// close the bitstream and returns an error if out-of-buffer reads occurred.
120func (b *bitReader) close() error {
121 // Release reference.
122 b.in = nil
Abhay Kumara2ae5992025-11-10 14:02:24 +0000123 b.cursor = 0
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530124 if !b.finished() {
125 return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
126 }
khenaidood948f772021-08-11 17:49:24 -0400127 if b.bitsRead > 64 {
128 return io.ErrUnexpectedEOF
129 }
130 return nil
131}
132
// highBits returns the zero-based position of the most significant set bit
// of val. For val == 0 the result wraps to 0xFFFFFFFF, so pass a non-zero
// value.
func highBits(val uint32) (n uint32) {
	n = uint32(31 - bits.LeadingZeros32(val))
	return n
}