blob: 6073ce7c4474691163efeaa47ad0ec784e962557 [file] [log] [blame]
Abhay Kumar40252eb2025-10-13 13:25:53 +00001package sarama
2
3import (
4 "sync"
5
6 "github.com/klauspost/compress/zstd"
7)
8
// zstdMaxBufferedEncoders caps how many idle zstd encoders are kept per
// ZstdEncoderParams. When that pool is empty a new encoder is created on
// demand; when it is full, a released encoder is simply discarded.
const (
	zstdMaxBufferedEncoders = 1
)
12
type (
	// ZstdEncoderParams describes a zstd encoder configuration. Values are
	// used directly as sync.Map keys for the encoder pool, so every field
	// must remain comparable.
	ZstdEncoderParams struct {
		// Level is the requested zstd compression level, or
		// CompressionLevelDefault to keep the encoder's default level.
		Level int
	}

	// ZstdDecoderParams describes a zstd decoder configuration. It has no
	// fields today and serves as the key of the shared decoder cache.
	ZstdDecoderParams struct{}
)
18
var (
	// zstdDecMap caches one *zstd.Decoder per ZstdDecoderParams; it is
	// populated lazily by getDecoder.
	zstdDecMap sync.Map

	// zstdAvailableEncoders maps each ZstdEncoderParams to a buffered
	// chan *zstd.Encoder holding idle encoders; see getZstdEncoderChannel.
	zstdAvailableEncoders sync.Map
)
22
23func getZstdEncoderChannel(params ZstdEncoderParams) chan *zstd.Encoder {
24 if c, ok := zstdAvailableEncoders.Load(params); ok {
25 return c.(chan *zstd.Encoder)
26 }
27 c, _ := zstdAvailableEncoders.LoadOrStore(params, make(chan *zstd.Encoder, zstdMaxBufferedEncoders))
28 return c.(chan *zstd.Encoder)
29}
30
31func getZstdEncoder(params ZstdEncoderParams) *zstd.Encoder {
32 select {
33 case enc := <-getZstdEncoderChannel(params):
34 return enc
35 default:
36 encoderLevel := zstd.SpeedDefault
37 if params.Level != CompressionLevelDefault {
38 encoderLevel = zstd.EncoderLevelFromZstd(params.Level)
39 }
40 zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
41 zstd.WithEncoderLevel(encoderLevel),
42 zstd.WithEncoderConcurrency(1))
43 return zstdEnc
44 }
45}
46
47func releaseEncoder(params ZstdEncoderParams, enc *zstd.Encoder) {
48 select {
49 case getZstdEncoderChannel(params) <- enc:
50 default:
51 }
52}
53
54func getDecoder(params ZstdDecoderParams) *zstd.Decoder {
55 if ret, ok := zstdDecMap.Load(params); ok {
56 return ret.(*zstd.Decoder)
57 }
58 // It's possible to race and create multiple new readers.
59 // Only one will survive GC after use.
60 zstdDec, _ := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
61 zstdDecMap.Store(params, zstdDec)
62 return zstdDec
63}
64
65func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) {
66 return getDecoder(params).DecodeAll(src, dst)
67}
68
69func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) {
70 enc := getZstdEncoder(params)
71 out := enc.EncodeAll(src, dst)
72 releaseEncoder(params, enc)
73 return out, nil
74}