diff --git a/go.mod b/go.mod index 4853716f6..fc0bdb808 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 - github.com/hamba/avro/v2 v2.19.0 + github.com/hamba/avro/v2 v2.21.1 github.com/hashicorp/go-hclog v1.6.3 github.com/hashicorp/go-plugin v1.6.0 github.com/jackc/pgx/v5 v5.5.5 diff --git a/go.sum b/go.sum index a71b7fed8..1b5ea852a 100644 --- a/go.sum +++ b/go.sum @@ -679,8 +679,8 @@ github.com/gostaticanalysis/testutil v0.4.0/go.mod h1:bLIoPefWXrRi/ssLFWX1dx7Rep github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM= -github.com/hamba/avro/v2 v2.19.0 h1:jITwvb03UMLfTFHFKdvaMyU/G96iVWS5EiMsqo3flfE= -github.com/hamba/avro/v2 v2.19.0/go.mod h1:72DkWmMmAyZA+qHoI89u4RMCQ3X54vpEb1ap80iCIBg= +github.com/hamba/avro/v2 v2.21.1 h1:400/jTdLWQ3ib58y83VXlTJKijRouYQszY1SO0cMGt4= +github.com/hamba/avro/v2 v2.21.1/go.mod h1:ouJ4PkiAEP49u0lAtQyd5Gv04MehKj+7lXwD3zpLpY0= github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= github.com/hanwen/go-fuse/v2 v2.1.0/go.mod h1:oRyA5eK+pvJyv5otpO/DgccS8y/RvYMaO00GgRLGryc= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= diff --git a/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/schema_test.go b/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/schema_test.go index 05b121e3d..a52a8f599 100644 --- a/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/schema_test.go +++ b/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/schema_test.go @@ -50,7 +50,7 @@ func TestSchema_MarshalUnmarshal(t *testing.T) { )), }, { name: "boolean ptr (nil)", - haveValue: func() *bool { return nil }(), + haveValue: (*bool)(nil), wantValue: nil, // when unmarshaling we get an untyped nil wantSchema: must(avro.NewUnionSchema( []avro.Schema{ @@ -63,56 +63,276 @@ func TestSchema_MarshalUnmarshal(t *testing.T) { haveValue: int(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int ptr (0)", + haveValue: func() *int { var v int; return &v }(), + wantValue: 0, // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int ptr (nil)", + haveValue: (*int)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "int64", haveValue: int64(1), wantValue: int64(1), wantSchema: avro.NewPrimitiveSchema(avro.Long, nil), + }, { + name: "int64 ptr (0)", + haveValue: func() *int64 { var v int64; return &v }(), + wantValue: int64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int64 ptr (nil)", + haveValue: (*int64)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "int32", haveValue: int32(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int32 ptr (0)", + haveValue: func() *int32 { var v int32; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int32 ptr (nil)", + haveValue: (*int32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "int16", haveValue: int16(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int16 ptr (0)", + haveValue: func() *int16 { var v int16; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int16 ptr (nil)", + haveValue: (*int16)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "int8", haveValue: int8(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "int8 ptr (0)", + haveValue: func() *int8 { var v int8; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "int8 ptr (nil)", + haveValue: (*int8)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "uint32", haveValue: uint32(1), wantValue: int64(1), wantSchema: avro.NewPrimitiveSchema(avro.Long, nil), + }, { + name: "uint32 ptr (0)", + haveValue: func() *uint32 { var v uint32; return &v }(), + wantValue: int64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint32 ptr (nil)", + haveValue: (*uint32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Long, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "uint16", haveValue: uint16(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "uint16 ptr (0)", + haveValue: func() *uint16 { var v uint16; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint16 ptr (nil)", + haveValue: (*uint16)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "uint8", haveValue: uint8(1), wantValue: int(1), wantSchema: avro.NewPrimitiveSchema(avro.Int, nil), + }, { + name: "uint8 ptr (0)", + haveValue: func() *uint8 { var v uint8; return &v }(), + wantValue: int(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "uint8 ptr (nil)", + haveValue: (*uint8)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Int, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "float64", haveValue: float64(1), wantValue: float64(1), wantSchema: avro.NewPrimitiveSchema(avro.Double, nil), + }, { + name: "float64 ptr (0)", + haveValue: func() *float64 { var v float64; return &v }(), + wantValue: float64(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Double, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float64 ptr (nil)", + haveValue: (*float64)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Double, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "float32", haveValue: float32(1), wantValue: float32(1), wantSchema: avro.NewPrimitiveSchema(avro.Float, nil), + }, { + name: "float32 ptr (0)", + haveValue: func() *float32 { var v float32; return &v }(), + wantValue: float32(0), // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Float, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "float32 ptr (nil)", + haveValue: (*float32)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.Float, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "string", haveValue: "1", wantValue: "1", wantSchema: avro.NewPrimitiveSchema(avro.String, nil), + }, { + name: "string ptr (empty)", + haveValue: func() *string { var v string; return &v }(), + wantValue: "", // ptr is unmarshalled into value + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), + }, { + name: "string ptr (nil)", + haveValue: (*string)(nil), + wantValue: nil, // when unmarshaling we get an untyped nil + wantSchema: must(avro.NewUnionSchema( + []avro.Schema{ + avro.NewPrimitiveSchema(avro.String, nil), + avro.NewPrimitiveSchema(avro.Null, nil), + }, + )), }, { name: "[]byte", haveValue: []byte{1, 2, 3}, @@ -152,7 +372,7 @@ func TestSchema_MarshalUnmarshal(t *testing.T) { }, { name: "[]any (no data)", haveValue: []any{}, - wantValue: []any{}, + wantValue: []any(nil), // TODO: smells like a bug, should be []any{} wantSchema: avro.NewArraySchema(must(avro.NewUnionSchema( // empty slice values default to nullable strings []avro.Schema{ avro.NewPrimitiveSchema(avro.String, nil), diff --git a/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/union.go b/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/union.go index da8659651..0ca325f17 100644 --- a/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/union.go +++ b/pkg/plugin/processor/builtin/impl/avro/schemaregistry/avro/union.go @@ -17,6 +17,7 @@ package avro import ( "reflect" + "github.com/conduitio/conduit-commons/opencdc" "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/hamba/avro/v2" "github.com/modern-go/reflect2" @@ -30,6 +31,7 @@ import ( type UnionResolver struct { mapUnionPaths []path arrayUnionPaths []path + nullUnionPaths []path resolver *avro.TypeResolver } @@ -40,6 +42,7 @@ type UnionResolver struct { func NewUnionResolver(schema avro.Schema) UnionResolver { var mapUnionPaths []path var arrayUnionPaths []path + var nullUnionPaths []path // traverse the schema and extract paths to all maps and arrays with a union // as the value type traverseSchema(schema, func(p path) { @@ -53,11 +56,17 @@ func NewUnionResolver(schema avro.Schema) UnionResolver { pCopy := make(path, len(p)) copy(pCopy, p) arrayUnionPaths = append(arrayUnionPaths, pCopy) + } else if isNullUnion(p[len(p)-1].schema) { + // path points to a null union, copy and store it + pCopy := make(path, len(p)-1) + copy(pCopy, p[:len(p)-1]) + nullUnionPaths = append(nullUnionPaths, pCopy) } }) return UnionResolver{ mapUnionPaths: mapUnionPaths, arrayUnionPaths: arrayUnionPaths, + nullUnionPaths: nullUnionPaths, resolver: avro.NewTypeResolver(), } } @@ -68,7 +77,9 @@ func NewUnionResolver(schema avro.Schema) UnionResolver { // (e.g. map[string]any{"string":"foo"}). This function takes that map and // extracts the actual value from it (e.g. "foo"). func (r UnionResolver) AfterUnmarshal(val any) error { - if len(r.mapUnionPaths) == 0 && len(r.arrayUnionPaths) == 0 { + if len(r.mapUnionPaths) == 0 && + len(r.arrayUnionPaths) == 0 && + len(r.nullUnionPaths) == 0 { return nil // shortcut } @@ -80,6 +91,10 @@ func (r UnionResolver) AfterUnmarshal(val any) error { if err != nil { return err } + substitutions, err = r.afterUnmarshalNullUnionSubstitutions(val, substitutions) + if err != nil { + return err + } // We now have a list of substitutions, simply apply them. for _, sub := range substitutions { @@ -184,6 +199,60 @@ func (r UnionResolver) afterUnmarshalArraySubstitutions(val any, substitutions [ return substitutions, nil } +func (r UnionResolver) afterUnmarshalNullUnionSubstitutions(val any, substitutions []substitution) ([]substitution, error) { + for _, p := range r.nullUnionPaths { + // first collect all values that are nullable + var maps []map[string]any + err := traverseValue(val, p, true, func(v any) { + switch v := v.(type) { + case map[string]any: + maps = append(maps, v) + case *map[string]any: + maps = append(maps, *v) + case *opencdc.StructuredData: + maps = append(maps, *v) + } + }) + if err != nil { + return nil, err + } + + // Loop through collected maps and collect all substitutions. These maps + // contain values encoded as maps with a single key:value pair, where + // key is the type name (e.g. {"int":1}). We want to replace all these + // maps with the actual value (e.g. 1). + // We don't replace them in the loop, because we want to make sure all + // maps actually contain only 1 value. + for i, mapUnion := range maps { + for k, v := range mapUnion { + if v == nil { + // do no change nil values + continue + } + vmap, ok := v.(map[string]any) + if !ok { + // if the value is not a map, it's not a nil value + continue + } + if len(vmap) != 1 { + return nil, cerrors.Errorf("expected single value encoded as a map, got %d elements", len(vmap)) + } + + // this is a map with a single value, store the substitution + for _, actualVal := range vmap { + substitutions = append(substitutions, mapSubstitution{ + m: maps[i], + key: k, + val: actualVal, + }) + break + } + } + } + } + return substitutions, nil +} + // BeforeMarshal traverses the value using the schema and finds all values that // have the Avro type Union. Those values need to be changed to a map with a // single key that contains the name of the type. This function takes that value @@ -373,6 +442,24 @@ func isArrayUnion(schema avro.Schema) bool { return false } +func isNullUnion(schema avro.Schema) bool { + s, ok := schema.(*avro.UnionSchema) + if !ok { + return false + } + if len(s.Types()) != 2 { + return false + } + for _, s := range s.Types() { + // at least one of the types in the union must be a map or array for this + // to count as a map with a union type + if s.Type() == avro.Null { + return true + } + } + return false +} + type substitution interface { substitute() }