blob: 8f8223cd3a67891da857d1bf2b6d1655debd6ba9 [file] [log] [blame]
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8 "crypto/rand"
Abhay Kumara2ae5992025-11-10 14:02:24 +00009 "errors"
khenaidood948f772021-08-11 17:49:24 -040010 "fmt"
11 "io"
Abhay Kumara2ae5992025-11-10 14:02:24 +000012 "math"
khenaidood948f772021-08-11 17:49:24 -040013 rdebug "runtime/debug"
14 "sync"
15
16 "github.com/klauspost/compress/zstd/internal/xxhash"
17)
18
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the encoder options. NewWriter fills in the defaults and then
	// applies the supplied EOptions on top.
	o encoderOptions
	// encoders is the pool of block encoders handed out by EncodeAll.
	// It is created by initialize with a capacity of o.concurrent.
	encoders chan encoder
	// state is the streaming state shared by Write, ReadFrom, Flush,
	// Close and Reset.
	state encoderState
	// init guards the one-time call to initialize made from EncodeAll.
	init sync.Once
}
31
32type encoder interface {
33 Encode(blk *blockEnc, src []byte)
34 EncodeNoHist(blk *blockEnc, src []byte)
35 Block() *blockEnc
36 CRC() *xxhash.Digest
37 AppendCRC([]byte) []byte
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053038 WindowSize(size int64) int32
khenaidood948f772021-08-11 17:49:24 -040039 UseBlock(*blockEnc)
40 Reset(d *dict, singleBlock bool)
41}
42
43type encoderState struct {
44 w io.Writer
45 filling []byte
46 current []byte
47 previous []byte
48 encoder encoder
49 writing *blockEnc
50 err error
51 writeErr error
52 nWritten int64
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +053053 nInput int64
54 frameContentSize int64
khenaidood948f772021-08-11 17:49:24 -040055 headerWritten bool
56 eofWritten bool
57 fullFrameWritten bool
58
59 // This waitgroup indicates an encode is running.
60 wg sync.WaitGroup
61 // This waitgroup indicates we have a block encoding/writing.
62 wWg sync.WaitGroup
63}
64
65// NewWriter will create a new Zstandard encoder.
66// If the encoder will be used for encoding blocks a nil writer can be used.
67func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
68 initPredefined()
69 var e Encoder
70 e.o.setDefault()
71 for _, o := range opts {
72 err := o(&e.o)
73 if err != nil {
74 return nil, err
75 }
76 }
77 if w != nil {
78 e.Reset(w)
79 }
80 return &e, nil
81}
82
83func (e *Encoder) initialize() {
84 if e.o.concurrent == 0 {
85 e.o.setDefault()
86 }
87 e.encoders = make(chan encoder, e.o.concurrent)
88 for i := 0; i < e.o.concurrent; i++ {
89 enc := e.o.encoder()
90 e.encoders <- enc
91 }
92}
93
94// Reset will re-initialize the writer and new writes will encode to the supplied writer
95// as a new, independent stream.
96func (e *Encoder) Reset(w io.Writer) {
97 s := &e.state
98 s.wg.Wait()
99 s.wWg.Wait()
100 if cap(s.filling) == 0 {
101 s.filling = make([]byte, 0, e.o.blockSize)
102 }
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530103 if e.o.concurrent > 1 {
104 if cap(s.current) == 0 {
105 s.current = make([]byte, 0, e.o.blockSize)
106 }
107 if cap(s.previous) == 0 {
108 s.previous = make([]byte, 0, e.o.blockSize)
109 }
110 s.current = s.current[:0]
111 s.previous = s.previous[:0]
112 if s.writing == nil {
113 s.writing = &blockEnc{lowMem: e.o.lowMem}
114 s.writing.init()
115 }
116 s.writing.initNewEncode()
khenaidood948f772021-08-11 17:49:24 -0400117 }
118 if s.encoder == nil {
119 s.encoder = e.o.encoder()
120 }
khenaidood948f772021-08-11 17:49:24 -0400121 s.filling = s.filling[:0]
khenaidood948f772021-08-11 17:49:24 -0400122 s.encoder.Reset(e.o.dict, false)
123 s.headerWritten = false
124 s.eofWritten = false
125 s.fullFrameWritten = false
126 s.w = w
127 s.err = nil
128 s.nWritten = 0
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530129 s.nInput = 0
khenaidood948f772021-08-11 17:49:24 -0400130 s.writeErr = nil
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530131 s.frameContentSize = 0
132}
133
134// ResetContentSize will reset and set a content size for the next stream.
135// If the bytes written does not match the size given an error will be returned
136// when calling Close().
137// This is removed when Reset is called.
138// Sizes <= 0 results in no content size set.
139func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
140 e.Reset(w)
141 if size >= 0 {
142 e.state.frameContentSize = size
143 }
khenaidood948f772021-08-11 17:49:24 -0400144}
145
146// Write data to the encoder.
147// Input data will be buffered and as the buffer fills up
148// content will be compressed and written to the output.
149// When done writing, use Close to flush the remaining output
150// and write CRC if requested.
151func (e *Encoder) Write(p []byte) (n int, err error) {
152 s := &e.state
Abhay Kumara2ae5992025-11-10 14:02:24 +0000153 if s.eofWritten {
154 return 0, ErrEncoderClosed
155 }
khenaidood948f772021-08-11 17:49:24 -0400156 for len(p) > 0 {
157 if len(p)+len(s.filling) < e.o.blockSize {
158 if e.o.crc {
159 _, _ = s.encoder.CRC().Write(p)
160 }
161 s.filling = append(s.filling, p...)
162 return n + len(p), nil
163 }
164 add := p
165 if len(p)+len(s.filling) > e.o.blockSize {
166 add = add[:e.o.blockSize-len(s.filling)]
167 }
168 if e.o.crc {
169 _, _ = s.encoder.CRC().Write(add)
170 }
171 s.filling = append(s.filling, add...)
172 p = p[len(add):]
173 n += len(add)
174 if len(s.filling) < e.o.blockSize {
175 return n, nil
176 }
177 err := e.nextBlock(false)
178 if err != nil {
179 return n, err
180 }
181 if debugAsserts && len(s.filling) > 0 {
182 panic(len(s.filling))
183 }
184 }
185 return n, nil
186}
187
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// final marks the block as the last one of the frame. Depending on the state
// one of these paths is taken:
//   - nothing has been written yet and final is set: the whole stream is
//     emitted synchronously as one complete frame (or nothing at all when
//     there is no input and fullZero is off);
//   - otherwise the frame header is written on first use, followed by
//     either an empty last block, a synchronously encoded block
//     (concurrent == 1) or an asynchronously encoded and written block.
//
// In the asynchronous case nil is returned immediately; errors from the
// background goroutines are picked up from s.err/s.writeErr by later calls.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// Closing a stream that never received data: unless zero frames
		// are requested nothing is written and the frame counts as done.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		// If we have a single block encode, do a sync compression.
		// encodeAll produces the complete frame, so the result is written
		// in one call and the frame is marked as fully written.
		if final && len(s.filling) > 0 {
			s.current = e.encodeAll(s.encoder, s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Streaming frame: build the header into a stack buffer.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst := fh.appendTo(tmp[:0])
		s.headerWritten = true
		// Make sure no asynchronous write is using s.w.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		// An empty raw block flagged as last terminates the frame.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// SYNC:
	// With a single encoder the block is encoded and written on the
	// calling goroutine; no buffer rotation is needed.
	if e.o.concurrent == 1 {
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if s.err != nil {
			return s.err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Move blocks forward.
	// The just-filled buffer becomes current (the input of this encode)
	// and the oldest buffer is recycled, emptied, as the new filling.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	// s.wg is released by the encode goroutine below; the next call to
	// nextBlock (and Flush/Close/Reset) waits on it.
	s.wg.Add(1)
	if final {
		s.eofWritten = true
	}
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// A panic in the encoder is converted into a sticky error
			// rather than crashing the process.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		// The entropy coding and the write to s.w run in a second
		// goroutine tracked by s.wWg, so this goroutine can finish and
		// release s.wg while the output is still being produced.
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			if s.writeErr != nil {
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
343
344// ReadFrom reads data from r until EOF or error.
345// The return value n is the number of bytes read.
346// Any error except io.EOF encountered during the read is also returned.
347//
348// The Copy function uses ReaderFrom if available.
349func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530350 if debugEncoder {
khenaidood948f772021-08-11 17:49:24 -0400351 println("Using ReadFrom")
352 }
353
354 // Flush any current writes.
355 if len(e.state.filling) > 0 {
356 if err := e.nextBlock(false); err != nil {
357 return 0, err
358 }
359 }
360 e.state.filling = e.state.filling[:e.o.blockSize]
361 src := e.state.filling
362 for {
363 n2, err := r.Read(src)
364 if e.o.crc {
365 _, _ = e.state.encoder.CRC().Write(src[:n2])
366 }
367 // src is now the unfilled part...
368 src = src[n2:]
369 n += int64(n2)
370 switch err {
371 case io.EOF:
372 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530373 if debugEncoder {
khenaidood948f772021-08-11 17:49:24 -0400374 println("ReadFrom: got EOF final block:", len(e.state.filling))
375 }
376 return n, nil
377 case nil:
378 default:
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530379 if debugEncoder {
khenaidood948f772021-08-11 17:49:24 -0400380 println("ReadFrom: got error:", err)
381 }
382 e.state.err = err
383 return n, err
384 }
385 if len(src) > 0 {
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530386 if debugEncoder {
khenaidood948f772021-08-11 17:49:24 -0400387 println("ReadFrom: got space left in source:", len(src))
388 }
389 continue
390 }
391 err = e.nextBlock(false)
392 if err != nil {
393 return n, err
394 }
395 e.state.filling = e.state.filling[:e.o.blockSize]
396 src = e.state.filling
397 }
398}
399
400// Flush will send the currently written data to output
401// and block until everything has been written.
402// This should only be used on rare occasions where pushing the currently queued data is critical.
403func (e *Encoder) Flush() error {
404 s := &e.state
405 if len(s.filling) > 0 {
406 err := e.nextBlock(false)
407 if err != nil {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000408 // Ignore Flush after Close.
409 if errors.Is(s.err, ErrEncoderClosed) {
410 return nil
411 }
khenaidood948f772021-08-11 17:49:24 -0400412 return err
413 }
414 }
415 s.wg.Wait()
416 s.wWg.Wait()
417 if s.err != nil {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000418 // Ignore Flush after Close.
419 if errors.Is(s.err, ErrEncoderClosed) {
420 return nil
421 }
khenaidood948f772021-08-11 17:49:24 -0400422 return s.err
423 }
424 return s.writeErr
425}
426
427// Close will flush the final output and close the stream.
428// The function will block until everything has been written.
429// The Encoder can still be re-used after calling this.
430func (e *Encoder) Close() error {
431 s := &e.state
432 if s.encoder == nil {
433 return nil
434 }
435 err := e.nextBlock(true)
436 if err != nil {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000437 if errors.Is(s.err, ErrEncoderClosed) {
438 return nil
439 }
khenaidood948f772021-08-11 17:49:24 -0400440 return err
441 }
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530442 if s.frameContentSize > 0 {
443 if s.nInput != s.frameContentSize {
444 return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
445 }
446 }
khenaidood948f772021-08-11 17:49:24 -0400447 if e.state.fullFrameWritten {
448 return s.err
449 }
450 s.wg.Wait()
451 s.wWg.Wait()
452
453 if s.err != nil {
454 return s.err
455 }
456 if s.writeErr != nil {
457 return s.writeErr
458 }
459
460 // Write CRC
461 if e.o.crc && s.err == nil {
462 // heap alloc.
463 var tmp [4]byte
464 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
465 s.nWritten += 4
466 }
467
468 // Add padding with content from crypto/rand.Reader
469 if s.err == nil && e.o.pad > 0 {
470 add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
471 frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
472 if err != nil {
473 return err
474 }
475 _, s.err = s.w.Write(frame)
476 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000477 if s.err == nil {
478 s.err = ErrEncoderClosed
479 return nil
480 }
481
khenaidood948f772021-08-11 17:49:24 -0400482 return s.err
483}
484
485// EncodeAll will encode all input in src and append it to dst.
486// This function can be called concurrently, but each call will only run on a single goroutine.
487// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
488// Encoded blocks can be concatenated and the result will be the combined input stream.
489// Data compressed with EncodeAll can be decoded with the Decoder,
490// using either a stream or DecodeAll.
491func (e *Encoder) EncodeAll(src, dst []byte) []byte {
Abhay Kumara2ae5992025-11-10 14:02:24 +0000492 e.init.Do(e.initialize)
493 enc := <-e.encoders
494 defer func() {
495 e.encoders <- enc
496 }()
497 return e.encodeAll(enc, src, dst)
498}
499
500func (e *Encoder) encodeAll(enc encoder, src, dst []byte) []byte {
khenaidood948f772021-08-11 17:49:24 -0400501 if len(src) == 0 {
502 if e.o.fullZero {
503 // Add frame header.
504 fh := frameHeader{
505 ContentSize: 0,
506 WindowSize: MinWindowSize,
507 SingleSegment: true,
508 // Adding a checksum would be a waste of space.
509 Checksum: false,
510 DictID: 0,
511 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000512 dst = fh.appendTo(dst)
khenaidood948f772021-08-11 17:49:24 -0400513
514 // Write raw block as last one only.
515 var blk blockHeader
516 blk.setSize(0)
517 blk.setType(blockTypeRaw)
518 blk.setLast(true)
519 dst = blk.appendTo(dst)
520 }
521 return dst
522 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000523
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530524 // Use single segments when above minimum window and below window size.
525 single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
khenaidood948f772021-08-11 17:49:24 -0400526 if e.o.single != nil {
527 single = *e.o.single
528 }
529 fh := frameHeader{
530 ContentSize: uint64(len(src)),
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530531 WindowSize: uint32(enc.WindowSize(int64(len(src)))),
khenaidood948f772021-08-11 17:49:24 -0400532 SingleSegment: single,
533 Checksum: e.o.crc,
534 DictID: e.o.dict.ID(),
535 }
536
537 // If less than 1MB, allocate a buffer up front.
538 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
539 dst = make([]byte, 0, len(src))
540 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000541 dst = fh.appendTo(dst)
khenaidood948f772021-08-11 17:49:24 -0400542
543 // If we can do everything in one block, prefer that.
Akash Reddy Kankanalacf045372025-06-10 14:11:24 +0530544 if len(src) <= e.o.blockSize {
khenaidood948f772021-08-11 17:49:24 -0400545 enc.Reset(e.o.dict, true)
546 // Slightly faster with no history and everything in one block.
547 if e.o.crc {
548 _, _ = enc.CRC().Write(src)
549 }
550 blk := enc.Block()
551 blk.last = true
552 if e.o.dict == nil {
553 enc.EncodeNoHist(blk, src)
554 } else {
555 enc.Encode(blk, src)
556 }
557
558 // If we got the exact same number of literals as input,
559 // assume the literals cannot be compressed.
khenaidood948f772021-08-11 17:49:24 -0400560 oldout := blk.output
Abhay Kumara2ae5992025-11-10 14:02:24 +0000561 // Output directly to dst
562 blk.output = dst
khenaidood948f772021-08-11 17:49:24 -0400563
Abhay Kumara2ae5992025-11-10 14:02:24 +0000564 err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
565 if err != nil {
khenaidood948f772021-08-11 17:49:24 -0400566 panic(err)
567 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000568 dst = blk.output
khenaidood948f772021-08-11 17:49:24 -0400569 blk.output = oldout
570 } else {
571 enc.Reset(e.o.dict, false)
572 blk := enc.Block()
573 for len(src) > 0 {
574 todo := src
575 if len(todo) > e.o.blockSize {
576 todo = todo[:e.o.blockSize]
577 }
578 src = src[len(todo):]
579 if e.o.crc {
580 _, _ = enc.CRC().Write(todo)
581 }
582 blk.pushOffsets()
583 enc.Encode(blk, todo)
584 if len(src) == 0 {
585 blk.last = true
586 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000587 err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
588 if err != nil {
khenaidood948f772021-08-11 17:49:24 -0400589 panic(err)
590 }
Abhay Kumara2ae5992025-11-10 14:02:24 +0000591 dst = append(dst, blk.output...)
khenaidood948f772021-08-11 17:49:24 -0400592 blk.reset(nil)
593 }
594 }
595 if e.o.crc {
596 dst = enc.AppendCRC(dst)
597 }
598 // Add padding with content from crypto/rand.Reader
599 if e.o.pad > 0 {
600 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
Abhay Kumara2ae5992025-11-10 14:02:24 +0000601 var err error
khenaidood948f772021-08-11 17:49:24 -0400602 dst, err = skippableFrame(dst, add, rand.Reader)
603 if err != nil {
604 panic(err)
605 }
606 }
607 return dst
608}
Abhay Kumara2ae5992025-11-10 14:02:24 +0000609
610// MaxEncodedSize returns the expected maximum
611// size of an encoded block or stream.
612func (e *Encoder) MaxEncodedSize(size int) int {
613 frameHeader := 4 + 2 // magic + frame header & window descriptor
614 if e.o.dict != nil {
615 frameHeader += 4
616 }
617 // Frame content size:
618 if size < 256 {
619 frameHeader++
620 } else if size < 65536+256 {
621 frameHeader += 2
622 } else if size < math.MaxInt32 {
623 frameHeader += 4
624 } else {
625 frameHeader += 8
626 }
627 // Final crc
628 if e.o.crc {
629 frameHeader += 4
630 }
631
632 // Max overhead is 3 bytes/block.
633 // There cannot be 0 blocks.
634 blocks := (size + e.o.blockSize) / e.o.blockSize
635
636 // Combine, add padding.
637 maxSz := frameHeader + 3*blocks + size
638 if e.o.pad > 1 {
639 maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
640 }
641 return maxSz
642}