From d973cba9694cffbde179bb9d4a2f74030b9baafe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 9 Jul 2024 16:22:47 +0200 Subject: [PATCH] Add schema serde utilities (#73) * add schema serde utilities * update go version * move avro builder into avro package, rename avro.Schema to avro.Serde * resolve lint warnings * do not expose Sort method * use bytes instead of string * clarify comment * add comment about sorting --- go.mod | 5 +- go.sum | 2 + proto/schema/v1/schema.pb.go | 31 +- proto/schema/v1/schema.proto | 14 +- rabin/rabin.go | 86 +++ rabin/rabin_test.go | 52 ++ schema/{ => avro}/avro_builder.go | 20 +- .../{ => avro}/avro_builder_example_test.go | 6 +- schema/{ => avro}/avro_builder_test.go | 6 +- schema/avro/errors.go | 22 + schema/avro/extractor.go | 404 +++++++++++ schema/avro/serde.go | 105 +++ schema/avro/serde_test.go | 657 ++++++++++++++++++ schema/avro/traverse.go | 194 ++++++ schema/avro/union.go | 483 +++++++++++++ schema/avro/union_test.go | 155 +++++ schema/errors.go | 10 +- schema/proto.go | 4 +- schema/proto_test.go | 5 +- schema/schema.go | 97 +++ 20 files changed, 2324 insertions(+), 34 deletions(-) create mode 100644 rabin/rabin.go create mode 100644 rabin/rabin_test.go rename schema/{ => avro}/avro_builder.go (79%) rename schema/{ => avro}/avro_builder_example_test.go (94%) rename schema/{ => avro}/avro_builder_test.go (91%) create mode 100644 schema/avro/errors.go create mode 100644 schema/avro/extractor.go create mode 100644 schema/avro/serde.go create mode 100644 schema/avro/serde_test.go create mode 100644 schema/avro/traverse.go create mode 100644 schema/avro/union.go create mode 100644 schema/avro/union_test.go diff --git a/go.mod b/go.mod index 069d183..0da7f52 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/conduitio/conduit-commons -go 1.21.1 +go 1.22.4 require ( github.com/bufbuild/buf v1.34.0 @@ -11,6 +11,8 @@ require ( github.com/hamba/avro/v2 v2.22.1 github.com/matryer/is v1.4.1 github.com/mitchellh/mapstructure v1.5.0 + github.com/modern-go/reflect2 v1.0.2 + github.com/twmb/go-cache v1.2.1 go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/tools v0.23.0 @@ -158,7 +160,6 @@ require ( github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/moricho/tparallel v0.3.1 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/nakabonne/nestif v0.3.1 // indirect diff --git a/go.sum b/go.sum index 742a126..b1631ce 100644 --- a/go.sum +++ b/go.sum @@ -662,6 +662,8 @@ github.com/tomarrell/wrapcheck/v2 v2.8.3 h1:5ov+Cbhlgi7s/a42BprYoxsr73CbdMUTzE3b github.com/tomarrell/wrapcheck/v2 v2.8.3/go.mod h1:g9vNIyhb5/9TQgumxQyOEqDHsmGYcGsVMOx/xGkqdMo= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= +github.com/twmb/go-cache v1.2.1 h1:yUkLutow4S2x5NMbqFW24o14OsucoFI5Fzmlb6uBinM= +github.com/twmb/go-cache v1.2.1/go.mod h1:lArg9KhCl+GTFMikitLGhIBh/i11OK0lhSveqlMbbrY= github.com/ultraware/funlen v0.1.0 h1:BuqclbkY6pO+cvxoq7OsktIXZpgBSkYTQtmwhAK81vI= github.com/ultraware/funlen v0.1.0/go.mod h1:XJqmOQja6DpxarLj6Jj1U7JuoS8PvL4nEqDaQhy22p4= github.com/ultraware/whitespace v0.1.1 h1:bTPOGejYFulW3PkcrqkeQwOd6NKOOXvmGD9bo/Gk8VQ= diff --git a/proto/schema/v1/schema.pb.go b/proto/schema/v1/schema.pb.go index 22c177b..3fe8312 100644 --- a/proto/schema/v1/schema.pb.go +++ b/proto/schema/v1/schema.pb.go @@ -72,11 +72,18 @@ type Schema struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Subject string `protobuf:"bytes,1,opt,name=subject,proto3" json:"subject,omitempty"` - Version int32 `protobuf:"varint,2,opt,name=version,proto3" json:"version,omitempty"` - Type Schema_Type `protobuf:"varint,3,opt,name=type,proto3,enum=schema.v1.Schema_Type" json:"type,omitempty"` - // The schema contents - Bytes []byte `protobuf:"bytes,4,opt,name=bytes,proto3" json:"bytes,omitempty"` + // The subject of the schema. Together with the version, this uniquely + // identifies the schema. + Subject string `protobuf:"bytes,1,opt,name=subject,proto3" json:"subject,omitempty"` + // The version of the schema. Together with the subject, this uniquely + // identifies the schema. + Version int32 `protobuf:"varint,2,opt,name=version,proto3" json:"version,omitempty"` + // Uniquely identifies the schema contents (not the schema itself!). + Id int32 `protobuf:"varint,3,opt,name=id,proto3" json:"id,omitempty"` + // The schema type. + Type Schema_Type `protobuf:"varint,4,opt,name=type,proto3,enum=schema.v1.Schema_Type" json:"type,omitempty"` + // The schema contents. + Bytes []byte `protobuf:"bytes,5,opt,name=bytes,proto3" json:"bytes,omitempty"` } func (x *Schema) Reset() { @@ -125,6 +132,13 @@ func (x *Schema) GetVersion() int32 { return 0 } +func (x *Schema) GetId() int32 { + if x != nil { + return x.Id + } + return 0 +} + func (x *Schema) GetType() Schema_Type { if x != nil { return x.Type @@ -144,14 +158,15 @@ var File_schema_v1_schema_proto protoreflect.FileDescriptor var file_schema_v1_schema_proto_rawDesc = []byte{ 0x0a, 0x16, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, - 0x2e, 0x76, 0x31, 0x22, 0xab, 0x01, 0x0a, 0x06, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x12, 0x18, + 0x2e, 0x76, 0x31, 0x22, 0xbb, 0x01, 0x0a, 0x06, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x75, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x12, 0x2a, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, + 0x69, 0x64, 0x12, 0x2a, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x14, - 0x0a, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x62, + 0x0a, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x22, 0x2b, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x10, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x41, 0x56, 0x52, 0x4f, 0x10, diff --git a/proto/schema/v1/schema.proto b/proto/schema/v1/schema.proto index cb8d049..330780b 100644 --- a/proto/schema/v1/schema.proto +++ b/proto/schema/v1/schema.proto @@ -11,9 +11,17 @@ message Schema { TYPE_AVRO = 1; } + // The subject of the schema. Together with the version, this uniquely + // identifies the schema. string subject = 1; + // The version of the schema. Together with the subject, this uniquely + // identifies the schema. int32 version = 2; - Type type = 3; - // The schema contents - bytes bytes = 4; + + // Uniquely identifies the schema contents (not the schema itself!). + int32 id = 3; + // The schema type. + Type type = 4; + // The schema contents. + bytes bytes = 5; } diff --git a/rabin/rabin.go b/rabin/rabin.go new file mode 100644 index 0000000..7a498e6 --- /dev/null +++ b/rabin/rabin.go @@ -0,0 +1,86 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package rabin provides a Rabin fingerprint hash.Hash64 implementation compatible +// with the Avro spec: https://avro.apache.org/docs/1.8.2/spec.html#schema_fingerprints. +package rabin + +import "hash" + +type digest uint64 + +// New constructs a new Rabin fingerprint hash.Hash64 initialized to the empty +// state according to the Avro spec: https://avro.apache.org/docs/1.8.2/spec.html#schema_fingerprints. +func New() hash.Hash64 { + var d digest + d.Reset() + return &d +} + +func (d *digest) Write(p []byte) (n int, err error) { + *d = update(*d, p) + return len(p), nil +} + +func (d *digest) Sum64() uint64 { + return uint64(*d) +} + +func (d *digest) Sum(in []byte) []byte { + s := d.Sum64() + return append(in, byte(s>>56), byte(s>>48), byte(s>>40), byte(s>>32), byte(s>>24), byte(s>>16), byte(s>>8), byte(s)) +} + +func (d *digest) Reset() { + *d = digest(rabinEmpty) +} + +func (d *digest) Size() int { return 8 } +func (d *digest) BlockSize() int { return 1 } + +const rabinEmpty = uint64(0xc15d213aa4d7a795) + +// rabinTable is used to compute the CRC-64-AVRO fingerprint. +var rabinTable = newRabinFingerprintTable() + +// newRabinFingerprintTable initializes the fingerprint table according to the +// spec: https://avro.apache.org/docs/1.8.2/spec.html#schema_fingerprints +func newRabinFingerprintTable() [256]uint64 { + fpTable := [256]uint64{} + for i := 0; i < 256; i++ { + fp := uint64(i) + for j := 0; j < 8; j++ { + fp = (fp >> 1) ^ (rabinEmpty & -(fp & 1)) + } + fpTable[i] = fp + } + return fpTable +} + +// Bytes creates a Rabin fingerprint according to the spec: +// https://avro.apache.org/docs/1.8.2/spec.html#schema_fingerprints +func Bytes(buf []byte) uint64 { + h := New() + _, _ = h.Write(buf) // it never returns an error + return h.Sum64() +} + +// update adds p to the running checksum d. +func update(d digest, p []byte) digest { + fp := uint64(d) + for i := 0; i < len(p); i++ { + fp = (fp >> 8) ^ rabinTable[(byte(fp)^p[i])&0xff] + } + return digest(fp) +} diff --git a/rabin/rabin_test.go b/rabin/rabin_test.go new file mode 100644 index 0000000..9483290 --- /dev/null +++ b/rabin/rabin_test.go @@ -0,0 +1,52 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabin + +import ( + "fmt" + "testing" + + "github.com/matryer/is" +) + +func TestRabin(t *testing.T) { + testCases := []struct { + have string + want uint64 + }{ + {have: `"int"`, want: 0x7275d51a3f395c8f}, + {have: `"string"`, want: 0x8f014872634503c7}, + {have: `"bool"`, want: 0x4a1c6b80ca0bcf48}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("Bytes_%v", tc.have), func(t *testing.T) { + is := is.New(t) + got := Bytes([]byte(tc.have)) + is.Equal(tc.want, got) + }) + t.Run(fmt.Sprintf("Hash_%v", tc.have), func(t *testing.T) { + is := is.New(t) + d := New() + + n, err := d.Write([]byte(tc.have)) + is.NoErr(err) + is.Equal(n, len(tc.have)) + + got := d.Sum64() + is.Equal(tc.want, got) + }) + } +} diff --git a/schema/avro_builder.go b/schema/avro/avro_builder.go similarity index 79% rename from schema/avro_builder.go rename to schema/avro/avro_builder.go index 7657af2..fa8830d 100644 --- a/schema/avro_builder.go +++ b/schema/avro/avro_builder.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package schema +package avro import ( "errors" @@ -21,21 +21,21 @@ import ( "github.com/hamba/avro/v2" ) -// AvroBuilder builds avro.RecordSchema instances and marshals them into JSON. -// AvroBuilder accepts arguments for creating fields and creates them internally +// Builder builds avro.RecordSchema instances and marshals them into JSON. +// Builder accepts arguments for creating fields and creates them internally // (i.e. a user doesn't need to create the fields). // All errors will be returned as a joined error when marshaling the schema to JSON. -type AvroBuilder struct { +type Builder struct { errs []error fields []*avro.Field name string namespace string } -// NewAvroBuilder constructs a new AvroBuilder and initializes it +// NewBuilder constructs a new Builder and initializes it // with the given name and namespace. -func NewAvroBuilder(name, namespace string) *AvroBuilder { - return &AvroBuilder{ +func NewBuilder(name, namespace string) *Builder { + return &Builder{ name: name, namespace: namespace, } @@ -44,7 +44,7 @@ func NewAvroBuilder(name, namespace string) *AvroBuilder { // AddField adds a new field with the given name, schema and schema options. // If creating the field returns an error, the error is saved, joined with // other errors (if any), and returned when marshaling to JSON. -func (b *AvroBuilder) AddField(name string, typ avro.Schema, opts ...avro.SchemaOption) *AvroBuilder { +func (b *Builder) AddField(name string, typ avro.Schema, opts ...avro.SchemaOption) *Builder { f, err := avro.NewField(name, typ, opts...) if err != nil { b.errs = append(b.errs, fmt.Errorf("field %v: %w", name, err)) @@ -58,7 +58,7 @@ func (b *AvroBuilder) AddField(name string, typ avro.Schema, opts ...avro.Schema // Build builds the underlying schema. // Errors that occurred while creating fields or constructing // the schema will be returned as a joined error. -func (b *AvroBuilder) Build() (*avro.RecordSchema, error) { +func (b *Builder) Build() (*avro.RecordSchema, error) { if b.errs != nil { return nil, errors.Join(b.errs...) } @@ -74,7 +74,7 @@ func (b *AvroBuilder) Build() (*avro.RecordSchema, error) { // MarshalJSON marshals the underlying schema to JSON. // Errors that occurred while creating fields, constructing // the schema or marshaling it will be returned as a joined error. -func (b *AvroBuilder) MarshalJSON() ([]byte, error) { +func (b *Builder) MarshalJSON() ([]byte, error) { schema, err := b.Build() if err != nil { return nil, err diff --git a/schema/avro_builder_example_test.go b/schema/avro/avro_builder_example_test.go similarity index 94% rename from schema/avro_builder_example_test.go rename to schema/avro/avro_builder_example_test.go index dd2676f..1b4a70e 100644 --- a/schema/avro_builder_example_test.go +++ b/schema/avro/avro_builder_example_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package schema +package avro import ( "fmt" @@ -21,12 +21,12 @@ import ( "github.com/hamba/avro/v2" ) -func ExampleAvroBuilder() { +func ExampleBuilder() { enumSchema, err := avro.NewEnumSchema("enum_schema", "enum_namespace", []string{"val1", "val2", "val3"}) if err != nil { panic(err) } - bytes, err := NewAvroBuilder("schema_name", "schema_namespace"). + bytes, err := NewBuilder("schema_name", "schema_namespace"). AddField("int_field", avro.NewPrimitiveSchema(avro.Int, nil), avro.WithDefault(100)). AddField("enum_field", enumSchema). MarshalJSON() diff --git a/schema/avro_builder_test.go b/schema/avro/avro_builder_test.go similarity index 91% rename from schema/avro_builder_test.go rename to schema/avro/avro_builder_test.go index 6628322..ad82a89 100644 --- a/schema/avro_builder_test.go +++ b/schema/avro/avro_builder_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package schema +package avro import ( "testing" @@ -21,7 +21,7 @@ import ( "github.com/matryer/is" ) -func TestAvroBuilder_Build(t *testing.T) { +func TestBuilder_Build(t *testing.T) { is := is.New(t) enumSchema, err := avro.NewEnumSchema("enum_schema", "enum_namespace", []string{"val1", "val2", "val3"}) @@ -43,7 +43,7 @@ func TestAvroBuilder_Build(t *testing.T) { want, err := wantSchema.MarshalJSON() is.NoErr(err) - got, err := NewAvroBuilder("schema_name", "schema_namespace"). + got, err := NewBuilder("schema_name", "schema_namespace"). AddField("int_field", avro.NewPrimitiveSchema(avro.Int, nil), avro.WithDefault(100)). AddField("enum_field", enumSchema). MarshalJSON() diff --git a/schema/avro/errors.go b/schema/avro/errors.go new file mode 100644 index 0000000..62114d1 --- /dev/null +++ b/schema/avro/errors.go @@ -0,0 +1,22 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import "errors" + +var ( + ErrUnsupportedType = errors.New("unsupported avro type") + ErrSchemaValueMismatch = errors.New("avro schema doesn't match supplied value") +) diff --git a/schema/avro/extractor.go b/schema/avro/extractor.go new file mode 100644 index 0000000..3d5e2fc --- /dev/null +++ b/schema/avro/extractor.go @@ -0,0 +1,404 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/conduitio/conduit-commons/opencdc" + "github.com/hamba/avro/v2" +) + +var ( + structuredDataType = reflect.TypeFor[opencdc.StructuredData]() + byteType = reflect.TypeFor[byte]() + timeType = reflect.TypeFor[time.Time]() +) + +// extractor exposes a way to extract an Avro schema from a Go value. +type extractor struct{} + +// Extract uses reflection to traverse the value and type of v and extract an +// Avro schema from it. There are some limitations that will cause this function +// to return an error, here are all known cases: +// - A fixed array of a type other than byte (e.g. [4]int). +// - A map with a key type other than string (e.g. map[int]any). +// - We only support built-in Avro types, which means that the following Go +// types are NOT supported: +// uint, uint64, complex64, complex128, chan, func, uintptr +// +// The function does its best to infer the schema, but it's working with limited +// information and has to make some assumptions: +// - If a map does not specify the type of its values (e.g. map[string]any), +// Extract will traverse all values in the map, extract their types and +// combine them in a union type. If the map is empty, the extracted value +// type will default to a nullable string (union type of string and null). +// - If a slice does not specify the type of its values (e.g. []any), Extract +// will traverse all values in the slice, extract their types and combine +// them in a union type. If the slice is empty, the extracted value type +// will default to a nullable string (union type of string and null). +// - If Extract encounters a value with the type of opencdc.StructuredData it +// will treat it as a record and extract a record schema, where each key in +// the structured data is extracted into its own record field. +func (e extractor) Extract(v any) (avro.Schema, error) { + return e.extract([]string{"record"}, reflect.ValueOf(v), reflect.TypeOf(v)) +} + +func (e extractor) extract(path []string, v reflect.Value, t reflect.Type) (avro.Schema, error) { + if t == nil { + return nil, fmt.Errorf("%s: can't get schema for untyped nil: %w", strings.Join(path, "."), ErrUnsupportedType) + } + switch t.Kind() { //nolint:exhaustive // some types are not supported + case reflect.Bool: + return avro.NewPrimitiveSchema(avro.Boolean, nil), nil + case reflect.Int64, reflect.Uint32: + return avro.NewPrimitiveSchema(avro.Long, nil), nil + case reflect.Int, reflect.Int32, reflect.Int16, reflect.Uint16, reflect.Int8, reflect.Uint8: + return avro.NewPrimitiveSchema(avro.Int, nil), nil + case reflect.Float32: + return avro.NewPrimitiveSchema(avro.Float, nil), nil + case reflect.Float64: + return avro.NewPrimitiveSchema(avro.Double, nil), nil + case reflect.String: + return avro.NewPrimitiveSchema(avro.String, nil), nil + case reflect.Pointer: + return e.extractPointer(path, v, t) + case reflect.Interface: + return e.extractInterface(path, v, t) + case reflect.Array: + if t.Elem() != byteType { + return nil, fmt.Errorf("%s: arrays with value type %v not supported, avro only supports bytes as values: %w", strings.Join(path, "."), t.Elem().String(), ErrUnsupportedType) + } + s, err := avro.NewFixedSchema(strings.Join(path, "."), "", t.Len(), nil) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return s, nil + case reflect.Slice: + return e.extractSlice(path, v, t) + case reflect.Map: + return e.extractMap(path, v, t) + case reflect.Struct: + if t == timeType { + return avro.NewPrimitiveSchema( + avro.Long, + avro.NewPrimitiveLogicalSchema(avro.TimestampMicros), + ), nil + } + return e.extractStruct(path, v, t) + default: + // Invalid, Uintptr, UnsafePointer, Uint64, Uint, Complex64, Complex128, Chan, Func + return nil, fmt.Errorf("%s: can't get schema for type %v: %w", strings.Join(path, "."), t, ErrUnsupportedType) + } +} + +// extractPointer extracts the schema behind the pointer and makes it nullable +// (if it's not already nullable). +func (e extractor) extractPointer(path []string, v reflect.Value, t reflect.Type) (avro.Schema, error) { + var vElem reflect.Value + if v.IsValid() { + vElem = v.Elem() + } + s, err := e.extract(path, vElem, t.Elem()) + if err != nil { + return nil, err + } + + var schemas avro.Schemas + if us, ok := s.(*avro.UnionSchema); ok && us.Nullable() { + // it's already a nullable schema + return s, nil + } else if ok { + // take types from union schema + schemas = us.Types() + } else if s.Type() != avro.Null { + // non-nil type + schemas = avro.Schemas{s} + } + + s, err = avro.NewUnionSchema(append(schemas, &avro.NullSchema{})) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + + return s, nil +} + +// extractInterface ignores the type, since an interface doesn't say anything +// about the concrete type behind it. Instead, it looks at the value behind the +// interface and tries to extract the schema based on its actual type. +// If the value is nil we have no way of knowing the actual type, but since we +// need to be able to encode untyped nil values, we default to a nullable string. +func (e extractor) extractInterface(path []string, v reflect.Value, _ reflect.Type) (avro.Schema, error) { + if !v.IsValid() || v.IsNil() { + // unknown type, fall back to nullable string + s, err := avro.NewUnionSchema([]avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + &avro.NullSchema{}, + }) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return s, nil + } + return e.extract(path, v.Elem(), v.Elem().Type()) +} + +// extractSlice tries to extract the schema based on the slice value type. If +// that type is an interface it falls back to looping through all values, +// extracting their types and combining them into a nullable union schema. +func (e extractor) extractSlice(path []string, v reflect.Value, t reflect.Type) (avro.Schema, error) { + if t.Elem().Kind() == reflect.Uint8 { + return avro.NewPrimitiveSchema(avro.Bytes, nil), nil + } + + // try getting value type based on the slice type + if t.Elem().Kind() != reflect.Interface { + vs, err := e.extract(append(path, "item"), reflect.Value{}, t.Elem()) + if err != nil { + return nil, err + } + return avro.NewArraySchema(vs), nil + } + + // this is []any, loop through all values and extracting their types + // into a union schema, null is included by default + types := []avro.Schema{&avro.NullSchema{}} + for i := 0; i < v.Len(); i++ { + itemSchema, err := e.extract( + append(path, fmt.Sprintf("item%d", i)), + v.Index(i), t.Elem(), + ) + if err != nil { + return nil, err + } + types = append(types, itemSchema) + } + // we could have duplicate schemas, deduplicate them + types, err := e.deduplicate(types) + if err != nil { + return nil, err + } + + if v.Len() == 0 { + // it's an empty slice, add string to types to have a valid schema + types = append(types, avro.NewPrimitiveSchema(avro.String, nil)) + } + + itemsSchema, err := avro.NewUnionSchema(types) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return avro.NewArraySchema(itemsSchema), nil +} + +// extractMap tries to extract the schema based on the map value type. If that +// type is an interface it falls back to looping through all values, extracting +// their types and combining them into a nullable union schema. +// If the key of the map is not a string, this function returns an error. If the +// type of the map is opencdc.StructuredData it will treat it as a record and +// extract a record schema, where each key in the structured data is extracted +// into its own record field. +func (e extractor) extractMap(path []string, v reflect.Value, t reflect.Type) (avro.Schema, error) { + if t == structuredDataType { + // special case - we treat StructuredData like a struct + var fields []*avro.Field + valType := t.Elem() + for _, keyValue := range v.MapKeys() { + fs, err := e.extract(append(path, keyValue.String()), v.MapIndex(keyValue), valType) + if err != nil { + return nil, err + } + field, err := avro.NewField(keyValue.String(), fs) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + fields = append(fields, field) + } + rs, err := avro.NewRecordSchema(strings.Join(path, "."), "", fields) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return rs, nil + } + if t.Key().Kind() != reflect.String { + return nil, fmt.Errorf("%s: maps with key type %v not supported, avro only supports strings as keys: %w", strings.Join(path, "."), t.Key().Kind(), ErrUnsupportedType) + } + // try getting value type based on the map type + if t.Elem().Kind() != reflect.Interface { + vs, err := e.extract(append(path, "value"), reflect.Value{}, t.Elem()) + if err != nil { + return nil, err + } + return avro.NewMapSchema(vs), nil + } + + // this is map[string]any, loop through all values and extracting their + // types into a union schema, null is included by default + types := []avro.Schema{&avro.NullSchema{}} + for _, kv := range v.MapKeys() { + valValue := v.MapIndex(kv) + vs, err := e.extract(append(path, "value"), valValue, t.Elem()) + if err != nil { + return nil, err + } + types = append(types, vs) + } + // we could have duplicate schemas, deduplicate them + types, err := e.deduplicate(types) + if err != nil { + return nil, err + } + + if len(v.MapKeys()) == 0 { + // it's an empty map, add string to types to have a valid schema + types = append(types, avro.NewPrimitiveSchema(avro.String, nil)) + } + vs, err := avro.NewUnionSchema(types) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return avro.NewMapSchema(vs), nil +} + +// extractStruct traverses the struct fields, extracts the schema for each field +// and combines them into a record schema. If the field contains a json tag, +// that tag is used for the extracted name of the field, otherwise it is the +// name of the Go struct field. If the json tag of a field contains "-" (i.e. +// ignored field), then the field is skipped. +func (e extractor) extractStruct(path []string, v reflect.Value, t reflect.Type) (avro.Schema, error) { + var fields []*avro.Field + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + name, ok := e.getStructFieldJSONName(sf) + if !ok { + continue // skip this field + } + var vfi reflect.Value + if v.IsValid() { + vfi = v.Field(i) + } + fs, err := e.extract(append(path, name), vfi, t.Field(i).Type) + if err != nil { + return nil, err + } + + field, err := avro.NewField(name, fs) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + fields = append(fields, field) + } + rs, err := avro.NewRecordSchema(strings.Join(path, "."), "", fields) + if err != nil { + return nil, fmt.Errorf("%s: %w", strings.Join(path, "."), err) + } + return rs, nil +} + +//nolint:gocognit,funlen // this function is complex by nature +func (e extractor) deduplicate(schemas []avro.Schema) ([]avro.Schema, error) { + out := make([]avro.Schema, 0, len(schemas)) + typesSet := make(map[[32]byte]struct{}) + + var appendSchema func(schema avro.Schema) error + appendSchema = func(schema avro.Schema) error { + if _, ok := typesSet[schema.Fingerprint()]; ok { + return nil + } + if us, ok := schema.(*avro.UnionSchema); ok { + for _, st := range us.Types() { + if err := appendSchema(st); err != nil { + return err + } + } + return nil + } + for _, s := range out { + if s.Type() != schema.Type() { + continue + } + switch s := s.(type) { + case *avro.ArraySchema: + // we are combining two array schemas with different item + // schemas, combine them and create a new array schema + schema, ok := schema.(*avro.ArraySchema) + if !ok { + return fmt.Errorf("can't combine schemas of type %T and %T: %w", s, schema, ErrSchemaValueMismatch) + } + itemsSchema, err := e.deduplicate([]avro.Schema{s.Items(), schema.Items()}) + if err != nil { + return err + } + if len(itemsSchema) == 1 { + *s = *avro.NewArraySchema(itemsSchema[0]) + } else { + itemsUnionSchema, err := avro.NewUnionSchema(itemsSchema) + if err != nil { + return fmt.Errorf("failed to create union schema: %w", err) + } + *s = *avro.NewArraySchema(itemsUnionSchema) + } + case *avro.MapSchema: + schema, ok := schema.(*avro.MapSchema) + if !ok { + return fmt.Errorf("can't combine schemas of type %T and %T: %w", s, schema, ErrSchemaValueMismatch) + } + valuesSchema, err := e.deduplicate([]avro.Schema{s.Values(), schema.Values()}) + if err != nil { + return err + } + if len(valuesSchema) == 1 { + *s = *avro.NewMapSchema(valuesSchema[0]) + } else { + valuesUnionSchema, err := avro.NewUnionSchema(valuesSchema) + if err != nil { + return fmt.Errorf("failed to create union schema: %w", err) + } + *s = *avro.NewMapSchema(valuesUnionSchema) + } + default: + return fmt.Errorf("can't combine schemas of type %T: %w", s, ErrUnsupportedType) + } + return nil + } + + // schema does not exist yet + out = append(out, schema) + typesSet[schema.Fingerprint()] = struct{}{} + return nil + } + + for _, schema := range schemas { + if err := appendSchema(schema); err != nil { + return nil, err + } + } + return out, nil +} + +func (extractor) getStructFieldJSONName(sf reflect.StructField) (string, bool) { + jsonTag := strings.Split(sf.Tag.Get("json"), ",")[0] // ignore tag options (omitempty) + if jsonTag == "-" { + return "", false + } + if jsonTag != "" { + return jsonTag, true + } + return sf.Name, true +} diff --git a/schema/avro/serde.go b/schema/avro/serde.go new file mode 100644 index 0000000..5c3ba58 --- /dev/null +++ b/schema/avro/serde.go @@ -0,0 +1,105 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "fmt" + + "github.com/hamba/avro/v2" +) + +// Serde represents an Avro schema. It exposes methods for marshaling and +// unmarshalling data. +type Serde struct { + schema avro.Schema + unionResolver unionResolver +} + +// Marshal returns the Avro encoding of v. Note that this function may mutate v. +// Limitations: +// - Map keys need to be of type string, +// - Array values need to be of type uint8 (byte). +func (s *Serde) Marshal(v any) ([]byte, error) { + err := s.unionResolver.BeforeMarshal(v) + if err != nil { + return nil, err + } + bytes, err := avro.Marshal(s.schema, v) + if err != nil { + return nil, fmt.Errorf("could not marshal into avro: %w", err) + } + return bytes, nil +} + +// Unmarshal parses the Avro encoded data and stores the result in the value +// pointed to by v. If v is nil or not a pointer, Unmarshal returns an error. +// Note that arrays and maps are unmarshalled into slices and maps with untyped +// values (i.e. []any and map[string]any). This is a limitation of the Avro +// library used for encoding/decoding the payload. +func (s *Serde) Unmarshal(b []byte, v any) error { + err := avro.Unmarshal(s.schema, b, v) + if err != nil { + return fmt.Errorf("could not unmarshal from avro: %w", err) + } + err = s.unionResolver.AfterUnmarshal(v) + if err != nil { + return err + } + return nil +} + +// String returns the canonical form of the schema. +func (s *Serde) String() string { + return s.schema.String() +} + +// sort fields in the schema. It can be used in tests to ensure the schemas can +// be compared. +func (s *Serde) sort() { + traverseSchema(s.schema, sortFn) +} + +// Parse parses a schema byte slice. +func Parse(text []byte) (*Serde, error) { + schema, err := avro.ParseBytes(text) + if err != nil { + return nil, fmt.Errorf("could not parse avro schema: %w", err) + } + // Note: We do not sort fields here because field order is significant in + // Avro schemas. Sorting would alter the schema and change the output. In + // SerdeForType, sorting ensures consistency when creating a schema from a + // value. However, when using Parse, we must preserve the original field + // order to match the schema definition. + return &Serde{ + schema: schema, + unionResolver: newUnionResolver(schema), + }, nil +} + +// SerdeForType uses reflection to extract an Avro schema from v. Maps are +// regarded as structs. +func SerdeForType(v any) (*Serde, error) { + schema, err := extractor{}.Extract(v) + if err != nil { + return nil, err + } + s := &Serde{ + schema: schema, + unionResolver: newUnionResolver(schema), + } + // Sort fields to ensure consistent schema representation. + s.sort() + return s, nil +} diff --git a/schema/avro/serde_test.go b/schema/avro/serde_test.go new file mode 100644 index 0000000..74b120f --- /dev/null +++ b/schema/avro/serde_test.go @@ -0,0 +1,657 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/conduitio/conduit-commons/opencdc" + "github.com/hamba/avro/v2" + "github.com/matryer/is" +) + +func TestSerde_MarshalUnmarshal(t *testing.T) { + now := time.Now().UTC() + + testCases := []struct { + name string + // haveValue is the value we use to extract the schema and which gets marshaled + haveValue any + // wantValue is the expected value we get when haveValue gets marshaled and unmarshaled + wantValue any + // wantSchema is the schema expected to be extracted from haveValue + wantSchema avro.Schema + }{{ + name: "boolean", + haveValue: true, + wantValue: true, + wantSchema: avro.NewPrimitiveSchema(avro.Boolean, nil), + }, { + name: "boolean ptr (false)", + haveValue: func() *bool { var v bool; return &v }(), + wantValue: false, // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Boolean, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "boolean ptr (nil)", + haveValue: (*bool)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Boolean, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int", + haveValue: int(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int ptr (0)", + haveValue: func() *int { var v int; return &v }(), + wantValue: 0, // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int ptr (nil)", + haveValue: (*int)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int64", + haveValue: int64(1), + wantValue: int64(1), + wantSchema: avro.NewPrimitiveSchema(avro.Long, nil), + }, { + name: "int64 ptr (0)", + haveValue: func() *int64 { var v int64; return &v }(), + wantValue: int64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int64 ptr (nil)", + haveValue: (*int64)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int32", + haveValue: int32(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int32 ptr (0)", + haveValue: func() *int32 { var v int32; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int32 ptr (nil)", + haveValue: (*int32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int16", + haveValue: int16(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int16 ptr (0)", + haveValue: func() *int16 { var v int16; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int16 ptr (nil)", + haveValue: (*int16)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int8", + haveValue: int8(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int8 ptr (0)", + haveValue: func() *int8 { var v int8; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int8 ptr (nil)", + haveValue: (*int8)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint32", + haveValue: uint32(1), + wantValue: int64(1), + wantSchema: avro.NewPrimitiveSchema(avro.Long, nil), + }, { + name: "uint32 ptr (0)", + haveValue: func() *uint32 { var v uint32; return &v }(), + wantValue: int64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint32 ptr (nil)", + haveValue: (*uint32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint16", + haveValue: uint16(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "uint16 ptr (0)", + haveValue: func() *uint16 { var v uint16; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint16 ptr (nil)", + haveValue: (*uint16)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint8", + haveValue: uint8(1), + wantValue: int(1), + wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "uint8 ptr (0)", + haveValue: func() *uint8 { var v uint8; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint8 ptr (nil)", + haveValue: (*uint8)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float64", + haveValue: float64(1), + wantValue: float64(1), + wantSchema: avro.NewPrimitiveSchema(avro.Double, nil), + }, { + name: "float64 ptr (0)", + haveValue: func() *float64 { var v float64; return &v }(), + wantValue: float64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Double, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float64 ptr (nil)", + haveValue: (*float64)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Double, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float32", + haveValue: float32(1), + wantValue: float32(1), + wantSchema: avro.NewPrimitiveSchema(avro.Float, nil), + }, { + name: "float32 ptr (0)", + haveValue: func() *float32 { var v float32; return &v }(), + wantValue: float32(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Float, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float32 ptr (nil)", + haveValue: (*float32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Float, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "string", + haveValue: "1", + wantValue: "1", + wantSchema: avro.NewPrimitiveSchema(avro.String, nil), + }, { + name: "string ptr (empty)", + haveValue: func() *string { var v string; return &v }(), + wantValue: "", // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "string ptr (nil)", + haveValue: (*string)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "[]byte", + haveValue: []byte{1, 2, 3}, + wantValue: []byte{1, 2, 3}, + wantSchema: avro.NewPrimitiveSchema(avro.Bytes, nil), + }, { + name: "[4]byte", + haveValue: [4]byte{1, 2, 3, 4}, + wantValue: [4]byte{1, 2, 3, 4}, + wantSchema: must(avro.NewFixedSchema("record.foo", "", 4, nil)), + }, { + name: "nil", + haveValue: nil, + wantValue: nil, + wantSchema: must(avro.NewUnionSchema( // untyped nils default to nullable strings + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "[]int", + haveValue: []int{1, 2, 3}, + wantValue: []any{1, 2, 3}, + wantSchema: avro.NewArraySchema(avro.NewPrimitiveSchema(avro.Int, nil)), + }, { + name: "[]any (with data)", + haveValue: []any{1, "foo"}, + wantValue: []any{1, "foo"}, + wantSchema: avro.NewArraySchema(must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + ))), + }, { + name: "[]any (no data)", + haveValue: []any{}, + wantValue: []any(nil), // TODO: smells like a bug, should be []any{} + wantSchema: avro.NewArraySchema(must(avro.NewUnionSchema( // empty slice values default to nullable strings + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + ))), + }, { + name: "[][]int", + haveValue: [][]int{{1}, {2, 3}}, + wantValue: []any{[]any{1}, []any{2, 3}}, + wantSchema: avro.NewArraySchema(avro.NewArraySchema(avro.NewPrimitiveSchema(avro.Int, nil))), + }, { + name: "map[string]int", + haveValue: map[string]int{ + "foo": 1, + "bar": 2, + }, + wantValue: map[string]any{ // all maps are unmarshaled into map[string]any + "foo": 1, + "bar": 2, + }, + wantSchema: avro.NewMapSchema(avro.NewPrimitiveSchema(avro.Int, nil)), + }, { + name: "map[string]any (with primitive data)", + haveValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": true, + }, + wantValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": true, + }, + wantSchema: avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Boolean, nil), + }))), + }, { + name: "map[string]any (with primitive array)", + haveValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": []int{1, 2, 3}, + }, + wantValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": []any{1, 2, 3}, + }, + wantSchema: avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewArraySchema(avro.NewPrimitiveSchema(avro.Int, nil)), + }))), + }, { + name: "map[string]any (with union array)", + haveValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": []int{1, 2, 3}, + "baz2": []any{"foo", true}, + }, + wantValue: map[string]any{ + "foo": "bar", + "foo2": "bar2", + "bar": 1, + "baz": []any{1, 2, 3}, + "baz2": []any{"foo", true}, + }, + wantSchema: avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewArraySchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Boolean, nil), + }))), + }))), + }, { + name: "map[string]any (no data)", + haveValue: map[string]any{}, + wantValue: map[string]any{}, + wantSchema: avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ // empty map values default to nullable strings + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.String, nil), + }))), + }, { + name: "map[string]any (nested)", + haveValue: map[string]any{ + "foo": map[string]any{ + "bar": "baz", + "baz": 1, + }, + }, + wantValue: map[string]any{ + "foo": map[string]any{ + "bar": "baz", + "baz": 1, + }, + }, + wantSchema: avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewMapSchema(must(avro.NewUnionSchema([]avro.Schema{ + &avro.NullSchema{}, + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.String, nil), + }))), + }))), + }, { + name: "opencdc.StructuredData", + haveValue: opencdc.StructuredData{ + "foo": "bar", + "bar": 1, + "baz": []int{1, 2, 3}, + "tz": now, + }, + wantValue: map[string]any{ // structured data is unmarshaled into a map + "foo": "bar", + "bar": 1, + "baz": []any{1, 2, 3}, + "tz": now.Truncate(time.Microsecond), // Avro cannot does not support nanoseconds + }, + wantSchema: must(avro.NewRecordSchema( + "record.foo", + "", + []*avro.Field{ + must(avro.NewField("foo", avro.NewPrimitiveSchema(avro.String, nil))), + must(avro.NewField("bar", avro.NewPrimitiveSchema(avro.Int, nil))), + must(avro.NewField("baz", avro.NewArraySchema(avro.NewPrimitiveSchema(avro.Int, nil)))), + must(avro.NewField("tz", avro.NewPrimitiveSchema(avro.Long, avro.NewPrimitiveLogicalSchema(avro.TimestampMicros)))), + }, + )), + }} + + newRecord := func(v any) opencdc.StructuredData { + return opencdc.StructuredData{"foo": v} + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + + // create new record with haveValue in field "foo" + haveValue := newRecord(tc.haveValue) + + // extract serde and ensure it matches the expectation + gotSerde, err := SerdeForType(haveValue) + is.NoErr(err) + + wantSerde := &Serde{ + schema: must(avro.NewRecordSchema("record", "", + []*avro.Field{must(avro.NewField("foo", tc.wantSchema))}, + )), + } + wantSerde.sort() + is.Equal(wantSerde.String(), gotSerde.String()) + + // now try to marshal the value with the schema + bytes, err := gotSerde.Marshal(haveValue) + is.NoErr(err) + + // unmarshal the bytes back into structured data and compare the value + var gotValue opencdc.StructuredData + err = gotSerde.Unmarshal(bytes, &gotValue) + is.NoErr(err) + + wantValue := newRecord(tc.wantValue) + is.Equal(wantValue, gotValue) + }) + } +} + +func TestSerdeForType_NestedStructuredData(t *testing.T) { + is := is.New(t) + + have := opencdc.StructuredData{ + "foo": "bar", + "level1": opencdc.StructuredData{ + "foo": "bar", + "level2": opencdc.StructuredData{ + "foo": "bar", + "level3": opencdc.StructuredData{ + "foo": "bar", + "regularMap": map[string]bool{}, + }, + }, + }, + } + + want := &Serde{schema: must(avro.NewRecordSchema( + "record", "", + []*avro.Field{ + must(avro.NewField("foo", avro.NewPrimitiveSchema(avro.String, nil))), + must(avro.NewField("level1", + must(avro.NewRecordSchema( + "record.level1", "", + []*avro.Field{ + must(avro.NewField("foo", avro.NewPrimitiveSchema(avro.String, nil))), + must(avro.NewField("level2", + must(avro.NewRecordSchema( + "record.level1.level2", "", + []*avro.Field{ + must(avro.NewField("foo", avro.NewPrimitiveSchema(avro.String, nil))), + must(avro.NewField("level3", + must(avro.NewRecordSchema( + "record.level1.level2.level3", "", + []*avro.Field{ + must(avro.NewField("foo", avro.NewPrimitiveSchema(avro.String, nil))), + must(avro.NewField("regularMap", avro.NewMapSchema( + avro.NewPrimitiveSchema(avro.Boolean, nil), + ))), + }, + )), + )), + }, + )), + )), + }, + )), + )), + }, + ))} + want.sort() + + got, err := SerdeForType(have) + is.NoErr(err) + is.Equal(want.String(), got.String()) + + bytes, err := got.Marshal(have) + is.NoErr(err) + // only try to unmarshal to ensure there's no error, other tests assert that + // umarshaled data matches the expectations + var unmarshaled opencdc.StructuredData + err = got.Unmarshal(bytes, &unmarshaled) + is.NoErr(err) +} + +func TestSerdeForType_UnsupportedTypes(t *testing.T) { + testCases := []struct { + val any + wantErr error + }{ + // avro only supports fixed byte arrays + {val: [4]int{}, wantErr: errors.New("record: arrays with value type int not supported, avro only supports bytes as values: unsupported avro type")}, + {val: [4]bool{}, wantErr: errors.New("record: arrays with value type bool not supported, avro only supports bytes as values: unsupported avro type")}, + // avro only supports maps with string keys + {val: map[int]string{}, wantErr: errors.New("record: maps with key type int not supported, avro only supports strings as keys: unsupported avro type")}, + {val: map[bool]string{}, wantErr: errors.New("record: maps with key type bool not supported, avro only supports strings as keys: unsupported avro type")}, + // avro only supports signed integers + {val: uint64(1), wantErr: errors.New("record: can't get schema for type uint64: unsupported avro type")}, + {val: uint(1), wantErr: errors.New("record: can't get schema for type uint: unsupported avro type")}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("%T", tc.val), func(t *testing.T) { + is := is.New(t) + _, err := SerdeForType(tc.val) + is.True(err != nil) + is.Equal(err.Error(), tc.wantErr.Error()) + }) + } +} + +func must[T any](f T, err error) T { + if err != nil { + panic(err) + } + return f +} diff --git a/schema/avro/traverse.go b/schema/avro/traverse.go new file mode 100644 index 0000000..8eced9f --- /dev/null +++ b/schema/avro/traverse.go @@ -0,0 +1,194 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "errors" + "fmt" + "reflect" + "sort" + + "github.com/conduitio/conduit-commons/opencdc" + "github.com/hamba/avro/v2" +) + +type ( + // path represents a path from the root to a certain type in an avro schema. + path []leg + // leg is a single leg of a path. + leg struct { + schema avro.Schema + field *avro.Field + } +) + +// traverseSchema is a utility for traversing an avro schema and executing fn on +// every schema in the tree. +func traverseSchema(s avro.Schema, fn func(path)) { + var traverse func(avro.Schema, path) + traverse = func(s avro.Schema, p path) { + p = append(p, leg{s, nil}) + fn(p) + + // traverse deeper into types that have nested types + switch s := s.(type) { + case *avro.MapSchema: + traverse(s.Values(), p) + case *avro.ArraySchema: + traverse(s.Items(), p) + case *avro.RefSchema: + traverse(s.Schema(), p) + case *avro.RecordSchema: + fields := s.Fields() + p = p[:len(p)-1] + for _, field := range fields { + p = append(p, leg{s, field}) + traverse(field.Type(), p) + p = p[:len(p)-1] + } + case *avro.UnionSchema: + for _, st := range s.Types() { + traverse(st, p) + } + } + } + traverse(s, nil) +} + +// sortFn can be passed to traverse to deterministically sort fields in every +// record and types in every union. +func sortFn(p path) { + switch s := p[len(p)-1].schema.(type) { + case *avro.RecordSchema: + fields := s.Fields() + sort.Slice(fields, func(i, j int) bool { + return fields[i].Name() < fields[j].Name() + }) + case *avro.UnionSchema: + schemas := s.Types() + sort.Slice(schemas, func(i, j int) bool { + return schemas[i].String() < schemas[j].String() + }) + } +} + +// traverseValue is a utility to traverse val down to the path and call fn with +// all values found at the end of the path. If hasEncodedUnions is set to true, +// any map and array with a union type is expected to contain a map[string]any +// with a single key representing the name of the type it contains +// (e.g. {"int": 1}). +// If the value structure does not match the path p, traverseValue returns an +// error. +// +//nolint:gocognit,funlen // need to switch on avro type and have a case for each type +func traverseValue(val any, p path, hasEncodedUnions bool, fn func(v any)) error { + var traverse func(any, int) error + traverse = func(val any, index int) error { + if index == len(p)-1 { + // reached the end of the path, call fn + fn(val) + return nil + } + if val == nil { + return nil // can't traverse further, not an error though + } + switch l := p[index]; l.schema.Type() { //nolint:exhaustive // some types are not supported + case avro.Record: + switch val := val.(type) { + case map[string]any: + return traverse(val[l.field.Name()], index+1) + case opencdc.StructuredData: + return traverse(val[l.field.Name()], index+1) + case *map[string]any: + return traverse(*val, index) // traverse value + case *opencdc.StructuredData: + return traverse(*val, index) // traverse value + } + return newUnexpectedTypeError(avro.Record, map[string]any{}, val) + case avro.Array: + valArr, ok := val.([]any) + if !ok { + return newUnexpectedTypeError(avro.Array, []any{}, val) + } + for _, item := range valArr { + if err := traverse(item, index+1); err != nil { + return err + } + } + return nil + case avro.Map: + valMap, ok := val.(map[string]any) + if !ok { + return newUnexpectedTypeError(avro.Map, map[string]any{}, val) + } + for _, v := range valMap { + if err := traverse(v, index+1); err != nil { + return err + } + } + return nil + case avro.Ref: + // ignore ref and go deeper + return traverse(val, index+1) + case avro.Union: + if hasEncodedUnions && index > 0 && + (p[index-1].schema.Type() == avro.Map || p[index-1].schema.Type() == avro.Array) { + // it's a union value encoded as a map, traverse it + valMap, ok := val.(map[string]any) + if !ok { + return newUnexpectedTypeError(avro.Union, map[string]any{}, val) + } + if len(valMap) != 1 { + return fmt.Errorf("expected single value encoded as a map, got %d elements: %w", len(valMap), ErrSchemaValueMismatch) + } + for _, v := range valMap { + return traverse(v, index+1) // there's only one value, return + } + } + + // values are encoded normally, skip union + err := traverse(val, index+1) + var uterr *unexpectedTypeError + if errors.As(err, &uterr) { + // We allow unexpected type errors, we could be traversing a + // different branch in the union type that does not have the + // same structure. + return nil + } + return err + default: + return fmt.Errorf("can not traverse deeper in avro type %s: %w", l.schema.Type(), ErrUnsupportedType) + } + } + return traverse(val, 0) +} + +type unexpectedTypeError struct { + avroType avro.Type + expectedGoType string + actualGoType string +} + +func newUnexpectedTypeError(avroType avro.Type, expected any, actual any) *unexpectedTypeError { + return &unexpectedTypeError{ + avroType: avroType, + expectedGoType: reflect.TypeOf(expected).String(), + actualGoType: reflect.TypeOf(actual).String(), + } +} + +func (e *unexpectedTypeError) Error() string { + return fmt.Sprintf("expected Go type %s for avro type %s, got %s", e.expectedGoType, e.avroType, e.actualGoType) +} diff --git a/schema/avro/union.go b/schema/avro/union.go new file mode 100644 index 0000000..92308fb --- /dev/null +++ b/schema/avro/union.go @@ -0,0 +1,483 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "fmt" + "reflect" + + "github.com/conduitio/conduit-commons/opencdc" + "github.com/hamba/avro/v2" + "github.com/modern-go/reflect2" +) + +// unionResolver provides hooks before marshaling and after unmarshaling a value +// with an Avro schema, which make sure that values under the schema type Union +// are in the correct shape (see https://github.com/hamba/avro#unions). +// NB: It currently supports union types nested in maps, but not nested in +// slices. For example, hooks will not work for values like []any{[]any{"foo"}}. +type unionResolver struct { + mapUnionPaths []path + arrayUnionPaths []path + nullUnionPaths []path + resolver *avro.TypeResolver +} + +// newUnionResolver takes a schema and extracts the paths to all maps and arrays +// with union types. With this information the resolver can traverse the values +// in BeforeMarshal and AfterUnmarshal directly to the value that needs to be +// substituted. +func newUnionResolver(schema avro.Schema) unionResolver { + var mapUnionPaths []path + var arrayUnionPaths []path + var nullUnionPaths []path + // traverse the schema and extract paths to all maps and arrays with a union + // as the value type + traverseSchema(schema, func(p path) { + switch { + case isMapUnion(p[len(p)-1].schema): + // path points to a map with a union type, copy and store it + pCopy := make(path, len(p)) + copy(pCopy, p) + mapUnionPaths = append(mapUnionPaths, pCopy) + case isArrayUnion(p[len(p)-1].schema): + // path points to an array with a union type, copy and store it + pCopy := make(path, len(p)) + copy(pCopy, p) + arrayUnionPaths = append(arrayUnionPaths, pCopy) + case isNullUnion(p[len(p)-1].schema): + // path points to a null union, copy and store it + pCopy := make(path, len(p)-1) + copy(pCopy, p[:len(p)-1]) + nullUnionPaths = append(nullUnionPaths, pCopy) + } + }) + return unionResolver{ + mapUnionPaths: mapUnionPaths, + arrayUnionPaths: arrayUnionPaths, + nullUnionPaths: nullUnionPaths, + resolver: avro.NewTypeResolver(), + } +} + +// AfterUnmarshal traverses the input value 'val' using the schema and finds all +// fields that are of the Avro Union type. In the input 'val', these Union type +// fields are represented as maps with a single key that contains the name of +// the type (e.g. map[string]any{"string":"foo"}). This function processes those +// maps and extracts the actual value from them (e.g. "foo"), replacing the map +// representation with the actual value. +func (r unionResolver) AfterUnmarshal(val any) error { + if len(r.mapUnionPaths) == 0 && + len(r.arrayUnionPaths) == 0 && + len(r.nullUnionPaths) == 0 { + return nil // shortcut + } + + substitutions, err := r.afterUnmarshalMapSubstitutions(val, nil) + if err != nil { + return err + } + substitutions, err = r.afterUnmarshalArraySubstitutions(val, substitutions) + if err != nil { + return err + } + substitutions, err = r.afterUnmarshalNullUnionSubstitutions(val, substitutions) + if err != nil { + return err + } + + // We now have a list of substitutions, simply apply them. + for _, sub := range substitutions { + sub.substitute() + } + return nil +} + +func (r unionResolver) afterUnmarshalMapSubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.mapUnionPaths { + // first collect all maps that have a union type in the schema + var maps []map[string]any + err := traverseValue(val, p, true, func(v any) { + if mapUnion, ok := v.(map[string]any); ok { + maps = append(maps, mapUnion) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected maps and collect all substitutions. These maps + // contain values encoded as maps with a single key:value pair, where + // key is the type name (e.g. {"int":1}). We want to replace all these + // maps with the actual value (e.g. 1). + // We don't replace them in the loop, because we want to make sure all + // maps actually contain only 1 value. + for i, mapUnion := range maps { + for k, v := range mapUnion { + if v == nil { + // do no change nil values + continue + } + vmap, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected map[string]any, got %T: %w", v, ErrSchemaValueMismatch) + } + if len(vmap) != 1 { + return nil, fmt.Errorf("expected single value encoded as a map, got %d elements: %w", len(vmap), ErrSchemaValueMismatch) + } + + // this is a map with a single value, store the substitution + for _, actualVal := range vmap { + substitutions = append(substitutions, mapSubstitution{ + m: maps[i], + key: k, + val: actualVal, + }) + break + } + } + } + } + return substitutions, nil +} + +func (r unionResolver) afterUnmarshalArraySubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.arrayUnionPaths { + // first collect all arrays that have a union type in the schema + var arrays [][]any + err := traverseValue(val, p, true, func(v any) { + if arrayUnion, ok := v.([]any); ok { + arrays = append(arrays, arrayUnion) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected arrays and collect all substitutions. These + // arrays contain values encoded as maps with a single key:value pair, + // where key is the type name (e.g. {"int":1}). We want to replace all + // these maps with the actual value (e.g. 1). + // We don't replace them in the loop, because we want to make sure all + // maps actually contain only 1 value. + for i, arrayUnion := range arrays { + for index, v := range arrayUnion { + if v == nil { + // do no change nil values + continue + } + vmap, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected map[string]any, got %T: %w", v, ErrSchemaValueMismatch) + } + if len(vmap) != 1 { + return nil, fmt.Errorf("expected single value encoded as a map, got %d elements: %w", len(vmap), ErrSchemaValueMismatch) + } + + // this is a map with a single value, store the substitution + for _, actualVal := range vmap { + substitutions = append(substitutions, arraySubstitution{ + a: arrays[i], + index: index, + val: actualVal, + }) + break + } + } + } + } + return substitutions, nil +} + +func (r unionResolver) afterUnmarshalNullUnionSubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.nullUnionPaths { + // first collect all values that are nullable + var maps []map[string]any + err := traverseValue(val, p, true, func(v any) { + switch v := v.(type) { + case map[string]any: + maps = append(maps, v) + case *map[string]any: + maps = append(maps, *v) + case *opencdc.StructuredData: + maps = append(maps, *v) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected maps and collect all substitutions. These maps + // contain values encoded as maps with a single key:value pair, where + // key is the type name (e.g. {"int":1}). We want to replace all these + // maps with the actual value (e.g. 1). + // We don't replace them in the loop, because we want to make sure all + // maps actually contain only 1 value. + for i, mapUnion := range maps { + for k, v := range mapUnion { + if v == nil { + // do no change nil values + continue + } + vmap, ok := v.(map[string]any) + if !ok { + // if the value is not a map, it's not a nil value + continue + } + if len(vmap) != 1 { + return nil, fmt.Errorf("expected single value encoded as a map, got %d elements: %w", len(vmap), ErrSchemaValueMismatch) + } + + // this is a map with a single value, store the substitution + for _, actualVal := range vmap { + substitutions = append(substitutions, mapSubstitution{ + m: maps[i], + key: k, + val: actualVal, + }) + break + } + } + } + } + return substitutions, nil +} + +// BeforeMarshal traverses the value using the schema and finds all values that +// have the Avro type Union. Those values need to be changed to a map with a +// single key that contains the name of the type. This function takes that value +// (e.g. "foo") and hoists it into a map (e.g. map[string]any{"string":"foo"}). +func (r unionResolver) BeforeMarshal(val any) error { + if len(r.mapUnionPaths) == 0 && len(r.arrayUnionPaths) == 0 { + return nil // shortcut + } + + substitutions, err := r.beforeMarshalMapSubstitutions(val, nil) + if err != nil { + return err + } + substitutions, err = r.beforeMarshalArraySubstitutions(val, substitutions) + if err != nil { + return err + } + + // We now have a list of substitutions, simply apply them. + for _, sub := range substitutions { + sub.substitute() + } + return nil +} + +func (r unionResolver) beforeMarshalMapSubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.mapUnionPaths { + mapSchema, ok := p[len(p)-1].schema.(*avro.MapSchema) + if !ok { + return nil, fmt.Errorf("expected *avro.MapSchema, got %T: %w", p[len(p)-1].schema, ErrSchemaValueMismatch) + } + unionSchema, ok := mapSchema.Values().(*avro.UnionSchema) + if !ok { + return nil, fmt.Errorf("expected *avro.UnionSchema, got %T: %w", mapSchema.Values(), ErrSchemaValueMismatch) + } + + // first collect all maps that have a union type in the schema + var maps []map[string]any + err := traverseValue(val, p, false, func(v any) { + if mapUnion, ok := v.(map[string]any); ok { + maps = append(maps, mapUnion) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected maps and collect all substitutions. We want + // to replace all non-nil values in these maps with maps that contain a + // single value, the key corresponds to the resolved name. + // We don't replace them in the loop, because we want to make sure all + // type names can be resolved first. + for i, mapUnion := range maps { + for k, v := range mapUnion { + if v == nil { + // do no change nil values + continue + } + valTypeName, err := r.resolveNameForType(v, unionSchema) + if err != nil { + return nil, err + } + substitutions = append(substitutions, mapSubstitution{ + m: maps[i], + key: k, + val: map[string]any{valTypeName: v}, + }) + } + } + } + return substitutions, nil +} + +func (r unionResolver) beforeMarshalArraySubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.arrayUnionPaths { + arraySchema, ok := p[len(p)-1].schema.(*avro.ArraySchema) + if !ok { + return nil, fmt.Errorf("expected *avro.ArraySchema, got %T: %w", p[len(p)-1].schema, ErrSchemaValueMismatch) + } + unionSchema, ok := arraySchema.Items().(*avro.UnionSchema) + if !ok { + return nil, fmt.Errorf("expected *avro.UnionSchema, got %T: %w", arraySchema.Items(), ErrSchemaValueMismatch) + } + + // first collect all array that have a union type in the schema + var arrays [][]any + err := traverseValue(val, p, false, func(v any) { + if arrayUnion, ok := v.([]any); ok { + arrays = append(arrays, arrayUnion) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected arrays and collect all substitutions. We want + // to replace all non-nil values in these arrays with maps that contain a + // single value, the key corresponds to the resolved name. + // We don't replace them in the loop, because we want to make sure all + // type names can be resolved first. + for i, arrayUnion := range arrays { + for index, v := range arrayUnion { + if v == nil { + // do no change nil values + continue + } + valTypeName, err := r.resolveNameForType(v, unionSchema) + if err != nil { + return nil, err + } + substitutions = append(substitutions, arraySubstitution{ + a: arrays[i], + index: index, + val: map[string]any{valTypeName: v}, + }) + } + } + } + return substitutions, nil +} + +func (r unionResolver) resolveNameForType(v any, us *avro.UnionSchema) (string, error) { + var names []string + + t := reflect2.TypeOf(v) + switch t.Kind() { //nolint:exhaustive // some types are not supported + case reflect.Map: + names = []string{"map"} + case reflect.Slice: + if !t.Type1().Elem().AssignableTo(byteType) { // []byte is handled differently + names = []string{"array"} + break + } + fallthrough + default: + var err error + names, err = r.resolver.Name(t) + if err != nil { + return "", fmt.Errorf("could not resolve type name for %T: %w", v, err) + } + } + + for _, n := range names { + _, pos := us.Types().Get(n) + if pos > -1 { + return n, nil + } + } + return "", fmt.Errorf("can't resolve %v in union type %v: %w", names, us.String(), ErrSchemaValueMismatch) +} + +func isMapUnion(schema avro.Schema) bool { + s, ok := schema.(*avro.MapSchema) + if !ok { + return false + } + us, ok := s.Values().(*avro.UnionSchema) + if !ok { + return false + } + for _, s := range us.Types() { + // at least one of the types in the union must be a map or array for this + // to count as a map with a union type + if s.Type() == avro.Array || s.Type() == avro.Map { + return true + } + } + return false +} + +func isArrayUnion(schema avro.Schema) bool { + s, ok := schema.(*avro.ArraySchema) + if !ok { + return false + } + us, ok := s.Items().(*avro.UnionSchema) + if !ok { + return false + } + for _, s := range us.Types() { + // at least one of the types in the union must be a map or array for this + // to count as a map with a union type + if s.Type() == avro.Array || s.Type() == avro.Map { + return true + } + } + return false +} + +func isNullUnion(schema avro.Schema) bool { + s, ok := schema.(*avro.UnionSchema) + if !ok { + return false + } + if len(s.Types()) != 2 { + return false + } + for _, s := range s.Types() { + // at least one of the types in the union must be a map or array for this + // to count as a map with a union type + if s.Type() == avro.Null { + return true + } + } + return false +} + +type substitution interface { + substitute() +} + +type mapSubstitution struct { + m map[string]any + key string + val any +} + +func (s mapSubstitution) substitute() { s.m[s.key] = s.val } + +type arraySubstitution struct { + a []any + index int + val any +} + +func (s arraySubstitution) substitute() { s.a[s.index] = s.val } diff --git a/schema/avro/union_test.go b/schema/avro/union_test.go new file mode 100644 index 0000000..50c9aa0 --- /dev/null +++ b/schema/avro/union_test.go @@ -0,0 +1,155 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package avro + +import ( + "reflect" + "testing" + + "github.com/conduitio/conduit-commons/opencdc" + "github.com/matryer/is" +) + +func TestUnionResolver(t *testing.T) { + is := is.New(t) + + testCases := []struct { + name string + have any + want any + }{{ + name: "string", + have: "foo", + want: map[string]any{"string": "foo"}, + }, { + name: "int", + have: 123, + want: map[string]any{"int": 123}, + }, { + name: "boolean", + have: true, + want: map[string]any{"boolean": true}, + }, { + name: "double", + have: 1.23, + want: map[string]any{"double": 1.23}, + }, { + name: "float", + have: float32(1.23), + want: map[string]any{"float": float32(1.23)}, + }, { + name: "long", + have: int64(321), + want: map[string]any{"long": int64(321)}, + }, { + name: "bytes", + have: []byte{1, 2, 3, 4}, + want: map[string]any{"bytes": []byte{1, 2, 3, 4}}, + }, { + name: "null", + have: nil, + want: nil, + }, { + name: "int array", + have: []int{1, 2, 3, 4}, + want: map[string]any{"array": []int{1, 2, 3, 4}}, + }, { + name: "nil bool array", + have: []bool(nil), + want: map[string]any{"array": []bool(nil)}, + }} + + isSlice := func(a any) bool { + if a == nil { + return false + } + // returns true if the type is a slice and not a byte slice + t := reflect.TypeOf(a) + return t.Kind() == reflect.Slice && !t.Elem().AssignableTo(byteType) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + is := is.New(t) + + newRecord := func() opencdc.StructuredData { + sd := opencdc.StructuredData{ + "foo1": tc.have, + "map1": map[string]any{ + "foo2": tc.have, + "map2": map[string]any{ + "foo3": tc.have, + }, + }, + "arr1": []any{ + tc.have, + []any{tc.have}, + }, + } + return sd + } + want := opencdc.StructuredData{ + "foo1": tc.have, // normal field shouldn't change + "map1": map[string]any{ + "foo2": tc.want, + "map2": map[string]any{ + "map": map[string]any{ + "foo3": func() any { + // if the original value is a slice, we consider + // the type a union and wrap it in a map, otherwise + // we keep the original value + if isSlice(tc.have) { + return tc.want + } + return tc.have + }(), + }, + }, + }, + "arr1": []any{ + tc.want, + map[string]any{ + "array": []any{ + func() any { + // if the original value is a slice, we consider + // the type a union and wrap it in a map, otherwise + // we keep the original value + if isSlice(tc.have) { + return tc.want + } + return tc.have + }(), + }, + }, + }, + } + have := newRecord() + + serde, err := SerdeForType(have) + is.NoErr(err) + mur := newUnionResolver(serde.schema) + + // before marshal we should change the nested map + err = mur.BeforeMarshal(have) + is.NoErr(err) + is.Equal(want, have) + + // after unmarshal we should have the same record as at the start + err = mur.AfterUnmarshal(have) + is.NoErr(err) + is.Equal(newRecord(), have) + }) + } +} diff --git a/schema/errors.go b/schema/errors.go index 716b7bc..6e80659 100644 --- a/schema/errors.go +++ b/schema/errors.go @@ -16,5 +16,11 @@ package schema import "errors" -// errInvalidProtoIsNil is returned when trying to convert a schema object to a proto schema, and the proto is nil. -var errInvalidProtoIsNil = errors.New("invalid proto: nil") +var ( + // ErrInvalidProtoIsNil is returned when trying to convert a schema object to a + // proto schema, and the proto is nil. + ErrInvalidProtoIsNil = errors.New("invalid proto: nil") + + // ErrUnsupportedType is returned when an unsupported type is encountered. + ErrUnsupportedType = errors.New("unsupported type") +) diff --git a/schema/proto.go b/schema/proto.go index 0cc3bd6..0deedf6 100644 --- a/schema/proto.go +++ b/schema/proto.go @@ -38,6 +38,7 @@ func (s *Schema) FromProto(proto *schemav1.Schema) error { s.Subject = proto.Subject s.Version = int(proto.Version) + s.ID = int(proto.Id) s.Type = Type(proto.Type) s.Bytes = proto.Bytes @@ -51,11 +52,12 @@ func (s *Schema) FromProto(proto *schemav1.Schema) error { // populated. func (s *Schema) ToProto(proto *schemav1.Schema) error { if proto == nil { - return errInvalidProtoIsNil + return ErrInvalidProtoIsNil } proto.Subject = s.Subject proto.Version = int32(s.Version) + proto.Id = int32(s.ID) proto.Type = schemav1.Schema_Type(s.Type) proto.Bytes = s.Bytes diff --git a/schema/proto_test.go b/schema/proto_test.go index 81192bb..987035b 100644 --- a/schema/proto_test.go +++ b/schema/proto_test.go @@ -15,6 +15,7 @@ package schema import ( + "errors" "testing" schemav1 "github.com/conduitio/conduit-commons/proto/schema/v1" @@ -57,7 +58,7 @@ func TestSchema_ToProto(t *testing.T) { name: "when proto object is nil", in: nil, want: nil, - wantErr: errInvalidProtoIsNil, + wantErr: ErrInvalidProtoIsNil, }, { name: "when proto object is not nil", @@ -88,7 +89,7 @@ func TestSchema_ToProto(t *testing.T) { is.NoErr(err) is.Equal(tc.in, tc.want) } else { - is.Equal(err.Error(), tc.wantErr.Error()) + is.True(errors.Is(err, tc.wantErr)) } }) } diff --git a/schema/schema.go b/schema/schema.go index 456db57..92c7b40 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -16,6 +16,15 @@ package schema +import ( + "fmt" + "time" + + "github.com/conduitio/conduit-commons/rabin" + "github.com/conduitio/conduit-commons/schema/avro" + "github.com/twmb/go-cache/cache" +) + type Type int32 const ( @@ -25,6 +34,94 @@ const ( type Schema struct { Subject string Version int + ID int Type Type Bytes []byte } + +// Marshal returns the encoded representation of v. +func (s Schema) Marshal(v any) ([]byte, error) { + srd, err := s.Serde() + if err != nil { + return nil, err + } + out, err := srd.Marshal(v) + if err != nil { + return nil, fmt.Errorf("failed to marshal data with schema %v:%v (id: %v): %w", s.Subject, s.Version, s.ID, err) + } + return out, nil +} + +// Unmarshal parses encoded data and stores the result in the value pointed +// to by v. If v is nil or not a pointer, Unmarshal returns an error. +func (s Schema) Unmarshal(b []byte, v any) error { + srd, err := s.Serde() + if err != nil { + return err + } + err = srd.Unmarshal(b, v) + if err != nil { + return fmt.Errorf("failed to unmarshal data with schema %v:%v (id: %v): %w", s.Subject, s.Version, s.ID, err) + } + return nil +} + +// Fingerprint returns a unique 64 bit identifier for the schema. +func (s Schema) Fingerprint() uint64 { + return rabin.Bytes(s.Bytes) +} + +// Serde returns the serde for the schema. +func (s Schema) Serde() (Serde, error) { + srd, err, _ := globalSerdeCache.Get(s.Fingerprint(), func() (Serde, error) { + factory, ok := KnownSerdeFactories[s.Type] + if !ok { + return nil, fmt.Errorf("failed to get serde for schema type %s: %w", s.Type, ErrUnsupportedType) + } + srd, err := factory.Parse(s.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse schema of type %s: %w", s.Type, err) + } + return srd, nil + }) + if err != nil { + return nil, err //nolint:wrapcheck // errors are already wrapped in the miss function + } + return srd, nil +} + +// globalSerdeCache is a concurrency safe cache of serdes by schema fingerprint. +// Every process uses a global cache to avoid re-parsing the same schema multiple +// times. Since the cache is global, it is important to ensure that the cache is +// cleaned up periodically to avoid memory leaks (e.g. if a pipeline is stopped +// and the schemas it processed are no longer needed). +var globalSerdeCache = cache.New[uint64, Serde]( + cache.AutoCleanInterval(time.Hour), // clean up every hour + cache.MaxAge(4*time.Hour), // expire entries after 4 hours +) + +// Serde represents a serializer/deserializer. +type Serde interface { + // Marshal returns the encoded representation of v. + Marshal(v any) ([]byte, error) + // Unmarshal parses encoded data and stores the result in the value pointed + // to by v. If v is nil or not a pointer, Unmarshal returns an error. + Unmarshal(b []byte, v any) error + // String returns the textual representation of the schema used by this serde. + String() string +} + +type SerdeFactory struct { + // Parse takes the textual representation of the schema and parses it into + // a Schema. + Parse func([]byte) (Serde, error) + // SerdeForType returns a Schema that matches the structure of v. + SerdeForType func(v any) (Serde, error) +} + +var KnownSerdeFactories = map[Type]SerdeFactory{ + TypeAvro: { + Parse: func(s []byte) (Serde, error) { return avro.Parse(s) }, + SerdeForType: func(v any) (Serde, error) { return avro.SerdeForType(v) }, + }, +}