Skip to content

Commit

Permalink
feat: allow setting zstd codec options
Browse files Browse the repository at this point in the history
  • Loading branch information
vpapp committed Jan 6, 2025
1 parent cb79c9f commit c742a5b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 23 deletions.
15 changes: 10 additions & 5 deletions ocf/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ const (
ZStandard CodecName = "zstandard"
)

func resolveCodec(name CodecName, lvl int) (Codec, error) {
// zstdOptions bundles the zstd encoder and decoder options that are handed to
// resolveCodec and forwarded to newZStandardCodec when the ZStandard codec is
// selected. The zero value carries no options.
type zstdOptions struct {
// EOptions are passed to zstd.NewWriter when the codec's encoder is built.
EOptions []zstd.EOption
// DOptions are passed to zstd.NewReader when the codec's decoder is built.
DOptions []zstd.DOption
}

func resolveCodec(name CodecName, lvl int, zstdOpts zstdOptions) (Codec, error) {
switch name {
case Null, "":
return &NullCodec{}, nil
Expand All @@ -36,7 +41,7 @@ func resolveCodec(name CodecName, lvl int) (Codec, error) {
return &SnappyCodec{}, nil

case ZStandard:
return newZStandardCodec(), nil
return newZStandardCodec(zstdOpts), nil

default:
return nil, fmt.Errorf("unknown codec %s", name)
Expand Down Expand Up @@ -132,9 +137,9 @@ type ZStandardCodec struct {
encoder *zstd.Encoder
}

func newZStandardCodec() *ZStandardCodec {
decoder, _ := zstd.NewReader(nil)
encoder, _ := zstd.NewWriter(nil)
func newZStandardCodec(opts zstdOptions) *ZStandardCodec {
decoder, _ := zstd.NewReader(nil, opts.DOptions...)
encoder, _ := zstd.NewWriter(nil, opts.EOptions...)
return &ZStandardCodec{
decoder: decoder,
encoder: encoder,
Expand Down
6 changes: 3 additions & 3 deletions ocf/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) {

input := makeTestData(8762, func() byte { return 'a' })

codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, 0, zstdOptions{})
require.NoError(b, err)

b.ReportAllocs()
Expand All @@ -74,7 +74,7 @@ func BenchmarkZstdEncodeDecodeLowEntropyLong(b *testing.B) {
func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) {
input := makeTestData(8762, func() byte { return byte(rand.Uint32()) })

codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, 0, zstdOptions{})
require.NoError(b, err)

b.ReportAllocs()
Expand All @@ -87,7 +87,7 @@ func BenchmarkZstdEncodeDecodeHighEntropyLong(b *testing.B) {
}

func verifyZstdEncodeDecode(t *testing.T, input []byte) {
codec, err := resolveCodec(ZStandard, 0)
codec, err := resolveCodec(ZStandard, 0, zstdOptions{})
require.NoError(t, err)

compressed := codec.Encode(input)
Expand Down
47 changes: 32 additions & 15 deletions ocf/ocf.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/hamba/avro/v2"
"github.com/hamba/avro/v2/internal/bytesx"
"github.com/klauspost/compress/zstd"
)

const (
Expand Down Expand Up @@ -52,8 +53,9 @@ type Header struct {
}

type decoderConfig struct {
DecoderConfig avro.API
SchemaCache *avro.SchemaCache
DecoderConfig avro.API
SchemaCache *avro.SchemaCache
ZStandardDecoderOptions []zstd.DOption
}

// DecoderFunc represents a configuration function for Decoder.
Expand All @@ -74,6 +76,13 @@ func WithDecoderSchemaCache(cache *avro.SchemaCache) DecoderFunc {
}
}

// WithZStandardDecoderOptions adds options for the ZStandard decoder.
//
// The given options are appended to any already present on the decoder
// configuration, so applying this function more than once accumulates options
// rather than replacing them. NewDecoder hands the collected options to the
// codec resolved from the file header.
func WithZStandardDecoderOptions(opts ...zstd.DOption) DecoderFunc {
return func(cfg *decoderConfig) {
cfg.ZStandardDecoderOptions = append(cfg.ZStandardDecoderOptions, opts...)
}
}

// Decoder reads and decodes Avro values from a container file.
type Decoder struct {
reader *avro.Reader
Expand All @@ -100,7 +109,7 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) {

reader := avro.NewReader(r, 1024)

h, err := readHeader(reader, cfg.SchemaCache)
h, err := readHeader(reader, cfg.SchemaCache, zstdOptions{DOptions: cfg.ZStandardDecoderOptions})
if err != nil {
return nil, fmt.Errorf("decoder: %w", err)
}
Expand Down Expand Up @@ -197,14 +206,15 @@ func (d *Decoder) readBlock() int64 {
}

type encoderConfig struct {
BlockLength int
CodecName CodecName
CodecCompression int
Metadata map[string][]byte
Sync [16]byte
EncodingConfig avro.API
SchemaCache *avro.SchemaCache
SchemaMarshaler func(avro.Schema) ([]byte, error)
BlockLength int
CodecName CodecName
CodecCompression int
ZStandardEncoderOptions []zstd.EOption
Metadata map[string][]byte
Sync [16]byte
EncodingConfig avro.API
SchemaCache *avro.SchemaCache
SchemaMarshaler func(avro.Schema) ([]byte, error)
}

// EncoderFunc represents a configuration function for Encoder.
Expand Down Expand Up @@ -233,6 +243,13 @@ func WithCompressionLevel(compLvl int) EncoderFunc {
}
}

// WithZStandardEncoderOptions adds options for the ZStandard encoder.
//
// The given options are appended to any already present on the encoder
// configuration, so applying this function more than once accumulates options
// rather than replacing them. newEncoder hands the collected options to the
// codec it resolves.
func WithZStandardEncoderOptions(opts ...zstd.EOption) EncoderFunc {
return func(cfg *encoderConfig) {
cfg.ZStandardEncoderOptions = append(cfg.ZStandardEncoderOptions, opts...)
}
}

// WithMetadata sets the metadata on the encoder header.
func WithMetadata(meta map[string][]byte) EncoderFunc {
return func(cfg *encoderConfig) {
Expand Down Expand Up @@ -316,7 +333,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e

if info.Size() > 0 {
reader := avro.NewReader(file, 1024)
h, err := readHeader(reader, cfg.SchemaCache)
h, err := readHeader(reader, cfg.SchemaCache, zstdOptions{EOptions: cfg.ZStandardEncoderOptions})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -354,7 +371,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e
_, _ = rand.Read(header.Sync[:])
}

codec, err := resolveCodec(cfg.CodecName, cfg.CodecCompression)
codec, err := resolveCodec(cfg.CodecName, cfg.CodecCompression, zstdOptions{EOptions: cfg.ZStandardEncoderOptions})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -469,7 +486,7 @@ type ocfHeader struct {
Sync [16]byte
}

func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader, error) {
func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, zstdOpts zstdOptions) (*ocfHeader, error) {
var h Header
reader.ReadVal(HeaderSchema, &h)
if reader.Error != nil {
Expand All @@ -484,7 +501,7 @@ func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader,
return nil, err
}

codec, err := resolveCodec(CodecName(h.Meta[codecKey]), -1)
codec, err := resolveCodec(CodecName(h.Meta[codecKey]), -1, zstdOpts)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c742a5b

Please sign in to comment.