blob: b7b83164bc76d572c6eb0541b25878e49575db63 [file] [log] [blame]
khenaidoo26721882021-08-11 17:42:52 -04001package zstd
2
3import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "io"
Abhay Kumar40252eb2025-10-13 13:25:53 +00009 "math"
10 "sort"
khenaidoo26721882021-08-11 17:42:52 -040011
12 "github.com/klauspost/compress/huff0"
13)
14
// dict holds a parsed zstd dictionary: the Huffman literal table, the three
// FSE sequence decoders, the initial repeat offsets and the raw content that
// is used as history, per the zstd dictionary format.
type dict struct {
	id uint32 // Dictionary ID; loadDict rejects 0.

	litEnc              *huff0.Scratch // Huffman table for literals.
	llDec, ofDec, mlDec sequenceDec    // Literal-length, offset and match-length decoders.
	offsets             [3]int         // Initial repeat offsets.
	content             []byte         // Raw dictionary content (history).
}
23
// dictMagic is the 4-byte magic prefix of every zstd dictionary
// (0xEC30A437, stored little-endian).
const dictMagic = "\x37\xa4\x30\xec"

// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
const dictMaxLength = 1 << 31
khenaidoo26721882021-08-11 17:42:52 -040028
29// ID returns the dictionary id or 0 if d is nil.
30func (d *dict) ID() uint32 {
31 if d == nil {
32 return 0
33 }
34 return d.id
35}
36
Abhay Kumar40252eb2025-10-13 13:25:53 +000037// ContentSize returns the dictionary content size or 0 if d is nil.
38func (d *dict) ContentSize() int {
khenaidoo26721882021-08-11 17:42:52 -040039 if d == nil {
40 return 0
41 }
42 return len(d.content)
43}
44
Abhay Kumar40252eb2025-10-13 13:25:53 +000045// Content returns the dictionary content.
46func (d *dict) Content() []byte {
47 if d == nil {
48 return nil
49 }
50 return d.content
51}
52
53// Offsets returns the initial offsets.
54func (d *dict) Offsets() [3]int {
55 if d == nil {
56 return [3]int{}
57 }
58 return d.offsets
59}
60
61// LitEncoder returns the literal encoder.
62func (d *dict) LitEncoder() *huff0.Scratch {
63 if d == nil {
64 return nil
65 }
66 return d.litEnc
67}
68
// Load a dictionary as described in
// https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
//
// Layout parsed here: 4-byte magic, 4-byte little-endian ID, Huffman literal
// table, three FSE tables (offsets, match lengths, literal lengths), three
// 4-byte repeat offsets, and finally the raw content.
func loadDict(b []byte) (*dict, error) {
	// Check static field size: magic + ID + the three repeat offsets.
	if len(b) <= 8+(3*4) {
		return nil, io.ErrUnexpectedEOF
	}
	d := dict{
		llDec: sequenceDec{fse: &fseDecoder{}},
		ofDec: sequenceDec{fse: &fseDecoder{}},
		mlDec: sequenceDec{fse: &fseDecoder{}},
	}
	if string(b[:4]) != dictMagic {
		return nil, ErrMagicMismatch
	}
	d.id = binary.LittleEndian.Uint32(b[4:8])
	if d.id == 0 {
		return nil, errors.New("dictionaries cannot have ID 0")
	}

	// Read literal table. ReadTable returns the remaining input after the table.
	var err error
	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
	if err != nil {
		return nil, fmt.Errorf("loading literal table: %w", err)
	}
	d.litEnc.Reuse = huff0.ReusePolicyMust

	br := byteReader{
		b:   b,
		off: 0,
	}
	// readDec reads one FSE table header from br (advancing it) into dec
	// and prepares the decoding table.
	readDec := func(i tableIndex, dec *fseDecoder) error {
		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
			return err
		}
		if br.overread() {
			return io.ErrUnexpectedEOF
		}
		err = dec.transform(symbolTableX[i])
		if err != nil {
			println("Transform table error:", err)
			return err
		}
		if debugDecoder || debugEncoder {
			println("Read table ok", "symbolLen:", dec.symbolLen)
		}
		// Set decoders as predefined so they aren't reused.
		dec.preDefined = true
		return nil
	}

	// Tables appear in this fixed order in the dictionary format.
	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
		return nil, err
	}
	// Need the three 4-byte repeat offsets.
	if br.remain() < 12 {
		return nil, io.ErrUnexpectedEOF
	}

	d.offsets[0] = int(br.Uint32())
	br.advance(4)
	d.offsets[1] = int(br.Uint32())
	br.advance(4)
	d.offsets[2] = int(br.Uint32())
	br.advance(4)
	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
		return nil, errors.New("invalid offset in dictionary")
	}
	// Everything after the tables is the dictionary content.
	d.content = make([]byte, br.remain())
	copy(d.content, br.unread())
	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
	}

	return &d, nil
}
Abhay Kumar40252eb2025-10-13 13:25:53 +0000151
152// InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
153func InspectDictionary(b []byte) (interface {
154 ID() uint32
155 ContentSize() int
156 Content() []byte
157 Offsets() [3]int
158 LitEncoder() *huff0.Scratch
159}, error) {
160 initPredefined()
161 d, err := loadDict(b)
162 return d, err
163}
164
// BuildDictOptions holds the samples and settings used by BuildDict to
// construct a zstd dictionary.
type BuildDictOptions struct {
	// Dictionary ID.
	ID uint32

	// Content to use to create dictionary tables.
	Contents [][]byte

	// History to use for all blocks.
	History []byte

	// Offsets to use.
	Offsets [3]int

	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	CompatV155 bool

	// Use the specified encoder level.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set SpeedBestCompression will be used.
	Level EncoderLevel

	// DebugOut will write stats and other details here if set.
	DebugOut io.Writer
}
191
192func BuildDict(o BuildDictOptions) ([]byte, error) {
193 initPredefined()
194 hist := o.History
195 contents := o.Contents
196 debug := o.DebugOut != nil
197 println := func(args ...interface{}) {
198 if o.DebugOut != nil {
199 fmt.Fprintln(o.DebugOut, args...)
200 }
201 }
202 printf := func(s string, args ...interface{}) {
203 if o.DebugOut != nil {
204 fmt.Fprintf(o.DebugOut, s, args...)
205 }
206 }
207 print := func(args ...interface{}) {
208 if o.DebugOut != nil {
209 fmt.Fprint(o.DebugOut, args...)
210 }
211 }
212
213 if int64(len(hist)) > dictMaxLength {
214 return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
215 }
216 if len(hist) < 8 {
217 return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
218 }
219 if len(contents) == 0 {
220 return nil, errors.New("no content provided")
221 }
222 d := dict{
223 id: o.ID,
224 litEnc: nil,
225 llDec: sequenceDec{},
226 ofDec: sequenceDec{},
227 mlDec: sequenceDec{},
228 offsets: o.Offsets,
229 content: hist,
230 }
231 block := blockEnc{lowMem: false}
232 block.init()
233 enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
234 if o.Level != 0 {
235 eOpts := encoderOptions{
236 level: o.Level,
237 blockSize: maxMatchLen,
238 windowSize: maxMatchLen,
239 dict: &d,
240 lowMem: false,
241 }
242 enc = eOpts.encoder()
243 } else {
244 o.Level = SpeedBestCompression
245 }
246 var (
247 remain [256]int
248 ll [256]int
249 ml [256]int
250 of [256]int
251 )
252 addValues := func(dst *[256]int, src []byte) {
253 for _, v := range src {
254 dst[v]++
255 }
256 }
257 addHist := func(dst *[256]int, src *[256]uint32) {
258 for i, v := range src {
259 dst[i] += int(v)
260 }
261 }
262 seqs := 0
263 nUsed := 0
264 litTotal := 0
265 newOffsets := make(map[uint32]int, 1000)
266 for _, b := range contents {
267 block.reset(nil)
268 if len(b) < 8 {
269 continue
270 }
271 nUsed++
272 enc.Reset(&d, true)
273 enc.Encode(&block, b)
274 addValues(&remain, block.literals)
275 litTotal += len(block.literals)
276 if len(block.sequences) == 0 {
277 continue
278 }
279 seqs += len(block.sequences)
280 block.genCodes()
281 addHist(&ll, block.coders.llEnc.Histogram())
282 addHist(&ml, block.coders.mlEnc.Histogram())
283 addHist(&of, block.coders.ofEnc.Histogram())
284 for i, seq := range block.sequences {
285 if i > 3 {
286 break
287 }
288 offset := seq.offset
289 if offset == 0 {
290 continue
291 }
292 if int(offset) >= len(o.History) {
293 continue
294 }
295 if offset > 3 {
296 newOffsets[offset-3]++
297 } else {
298 newOffsets[uint32(o.Offsets[offset-1])]++
299 }
300 }
301 }
302 // Find most used offsets.
303 var sortedOffsets []uint32
304 for k := range newOffsets {
305 sortedOffsets = append(sortedOffsets, k)
306 }
307 sort.Slice(sortedOffsets, func(i, j int) bool {
308 a, b := sortedOffsets[i], sortedOffsets[j]
309 if a == b {
310 // Prefer the longer offset
311 return sortedOffsets[i] > sortedOffsets[j]
312 }
313 return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
314 })
315 if len(sortedOffsets) > 3 {
316 if debug {
317 print("Offsets:")
318 for i, v := range sortedOffsets {
319 if i > 20 {
320 break
321 }
322 printf("[%d: %d],", v, newOffsets[v])
323 }
324 println("")
325 }
326
327 sortedOffsets = sortedOffsets[:3]
328 }
329 for i, v := range sortedOffsets {
330 o.Offsets[i] = int(v)
331 }
332 if debug {
333 println("New repeat offsets", o.Offsets)
334 }
335
336 if nUsed == 0 || seqs == 0 {
337 return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
338 }
339 if debug {
340 println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
341 }
342 if seqs/nUsed < 512 {
343 // Use 512 as minimum.
344 nUsed = seqs / 512
345 if nUsed == 0 {
346 nUsed = 1
347 }
348 }
349 copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
350 hist := dst.Histogram()
351 var maxSym uint8
352 var maxCount int
353 var fakeLength int
354 for i, v := range src {
355 if v > 0 {
356 v = v / nUsed
357 if v == 0 {
358 v = 1
359 }
360 }
361 if v > maxCount {
362 maxCount = v
363 }
364 if v != 0 {
365 maxSym = uint8(i)
366 }
367 fakeLength += v
368 hist[i] = uint32(v)
369 }
370
371 // Ensure we aren't trying to represent RLE.
372 if maxCount == fakeLength {
373 for i := range hist {
374 if uint8(i) == maxSym {
375 fakeLength++
376 maxSym++
377 hist[i+1] = 1
378 if maxSym > 1 {
379 break
380 }
381 }
382 if hist[0] == 0 {
383 fakeLength++
384 hist[i] = 1
385 if maxSym > 1 {
386 break
387 }
388 }
389 }
390 }
391
392 dst.HistogramFinished(maxSym, maxCount)
393 dst.reUsed = false
394 dst.useRLE = false
395 err := dst.normalizeCount(fakeLength)
396 if err != nil {
397 return nil, err
398 }
399 if debug {
400 println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
401 }
402 return dst.writeCount(nil)
403 }
404 if debug {
405 print("Literal lengths: ")
406 }
407 llTable, err := copyHist(block.coders.llEnc, &ll)
408 if err != nil {
409 return nil, err
410 }
411 if debug {
412 print("Match lengths: ")
413 }
414 mlTable, err := copyHist(block.coders.mlEnc, &ml)
415 if err != nil {
416 return nil, err
417 }
418 if debug {
419 print("Offsets: ")
420 }
421 ofTable, err := copyHist(block.coders.ofEnc, &of)
422 if err != nil {
423 return nil, err
424 }
425
426 // Literal table
427 avgSize := litTotal
428 if avgSize > huff0.BlockSizeMax/2 {
429 avgSize = huff0.BlockSizeMax / 2
430 }
431 huffBuff := make([]byte, 0, avgSize)
432 // Target size
433 div := litTotal / avgSize
434 if div < 1 {
435 div = 1
436 }
437 if debug {
438 println("Huffman weights:")
439 }
440 for i, n := range remain[:] {
441 if n > 0 {
442 n = n / div
443 // Allow all entries to be represented.
444 if n == 0 {
445 n = 1
446 }
447 huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
448 if debug {
449 printf("[%d: %d], ", i, n)
450 }
451 }
452 }
453 if o.CompatV155 && remain[255]/div == 0 {
454 huffBuff = append(huffBuff, 255)
455 }
456 scratch := &huff0.Scratch{TableLog: 11}
457 for tries := 0; tries < 255; tries++ {
458 scratch = &huff0.Scratch{TableLog: 11}
459 _, _, err = huff0.Compress1X(huffBuff, scratch)
460 if err == nil {
461 break
462 }
463 if debug {
464 printf("Try %d: Huffman error: %v\n", tries+1, err)
465 }
466 huffBuff = huffBuff[:0]
467 if tries == 250 {
468 if debug {
469 println("Huffman: Bailing out with predefined table")
470 }
471
472 // Bail out.... Just generate something
473 huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
474 for i := 0; i < 128; i++ {
475 huffBuff = append(huffBuff, byte(i))
476 }
477 continue
478 }
479 if errors.Is(err, huff0.ErrIncompressible) {
480 // Try truncating least common.
481 for i, n := range remain[:] {
482 if n > 0 {
483 n = n / (div * (i + 1))
484 if n > 0 {
485 huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
486 }
487 }
488 }
489 if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
490 huffBuff = append(huffBuff, 255)
491 }
492 if len(huffBuff) == 0 {
493 huffBuff = append(huffBuff, 0, 255)
494 }
495 }
496 if errors.Is(err, huff0.ErrUseRLE) {
497 for i, n := range remain[:] {
498 n = n / (div * (i + 1))
499 // Allow all entries to be represented.
500 if n == 0 {
501 n = 1
502 }
503 huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
504 }
505 }
506 }
507
508 var out bytes.Buffer
509 out.Write([]byte(dictMagic))
510 out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
511 out.Write(scratch.OutTable)
512 if debug {
513 println("huff table:", len(scratch.OutTable), "bytes")
514 println("of table:", len(ofTable), "bytes")
515 println("ml table:", len(mlTable), "bytes")
516 println("ll table:", len(llTable), "bytes")
517 }
518 out.Write(ofTable)
519 out.Write(mlTable)
520 out.Write(llTable)
521 out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
522 out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
523 out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
524 out.Write(hist)
525 if debug {
526 _, err := loadDict(out.Bytes())
527 if err != nil {
528 panic(err)
529 }
530 i, err := InspectDictionary(out.Bytes())
531 if err != nil {
532 panic(err)
533 }
534 println("ID:", i.ID())
535 println("Content size:", i.ContentSize())
536 println("Encoder:", i.LitEncoder() != nil)
537 println("Offsets:", i.Offsets())
538 var totalSize int
539 for _, b := range contents {
540 totalSize += len(b)
541 }
542
543 encWith := func(opts ...EOption) int {
544 enc, err := NewWriter(nil, opts...)
545 if err != nil {
546 panic(err)
547 }
548 defer enc.Close()
549 var dst []byte
550 var totalSize int
551 for _, b := range contents {
552 dst = enc.EncodeAll(b, dst[:0])
553 totalSize += len(dst)
554 }
555 return totalSize
556 }
557 plain := encWith(WithEncoderLevel(o.Level))
558 withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
559 println("Input size:", totalSize)
560 println("Plain Compressed:", plain)
561 println("Dict Compressed:", withDict)
562 println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
563 }
564 return out.Bytes(), nil
565}