Skip to content

Commit

Permalink
Schemas: fix handling of nullable fields (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
hariso authored Nov 8, 2024
1 parent 9537066 commit d237fb3
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 48 deletions.
4 changes: 3 additions & 1 deletion schema/avro/traverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (
)

type (
// path represents a path from the root to a certain type in an avro schema.
// path represents a path from the root to a certain field/type in an Avro schema.
path []leg
// leg is a single leg of a path.
leg struct {
// schema is the schema of the object that contains the below field
// (i.e. it's not the schema of the field itself).
schema avro.Schema
field *avro.Field
}
Expand Down
139 changes: 92 additions & 47 deletions schema/avro/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ import (
// NB: It currently supports union types nested in maps, but not nested in
// slices. For example, hooks will not work for values like []any{[]any{"foo"}}.
type unionResolver struct {
mapUnionPaths []path
// mapUnionPaths are all the paths to map fields within a schema,
// where value types are union types
mapUnionPaths []path
// arrayUnionPaths are all the paths to array fields within a schema,
// where value types are union types
arrayUnionPaths []path
nullUnionPaths []path
resolver *avro.TypeResolver
// nullUnionPaths are all the paths to nullable fields within a schema.
nullUnionPaths []path
resolver *avro.TypeResolver
}

// newUnionResolver takes a schema and extracts the paths to all maps and arrays
Expand Down Expand Up @@ -202,59 +207,85 @@ func (r unionResolver) afterUnmarshalArraySubstitutions(val any, substitutions [
}

func (r unionResolver) afterUnmarshalNullUnionSubstitutions(val any, substitutions []substitution) ([]substitution, error) {
for _, p := range r.nullUnionPaths {
// first collect all values that are nullable
var maps []map[string]any
err := traverseValue(val, p, true, func(v any) {
switch v := v.(type) {
case map[string]any:
maps = append(maps, v)
case *map[string]any:
maps = append(maps, *v)
case *opencdc.StructuredData:
maps = append(maps, *v)
}
})
for _, nullUnionPath := range r.nullUnionPaths {
// first collect all parents that contain a value that is nullable
parentMaps, err := r.collectParentsForNullUnionPath(val, nullUnionPath)
if err != nil {
return nil, err
}

// Loop through collected maps and collect all substitutions. These maps
// contain values encoded as maps with a single key:value pair, where
// key is the type name (e.g. {"int":1}). We want to replace all these
// maps with the actual value (e.g. 1).
// We don't replace them in the loop, because we want to make sure all
// maps actually contain only 1 value.
for i, mapUnion := range maps {
for k, v := range mapUnion {
if v == nil {
// do no change nil values
continue
}
vmap, ok := v.(map[string]any)
if !ok {
// if the value is not a map, it's not a nil value
continue
}
if len(vmap) != 1 {
return nil, fmt.Errorf("expected single value encoded as a map, got %d elements: %w", len(vmap), ErrSchemaValueMismatch)
}

// this is a map with a single value, store the substitution
for _, actualVal := range vmap {
substitutions = append(substitutions, mapSubstitution{
m: maps[i],
key: k,
val: actualVal,
})
break
// nullUnionField is the fields that needs to be substitured.
// It's the last leg in the path.
nullUnionField := nullUnionPath[len(nullUnionPath)-1].field

// Loop through collected parent maps and collect all substitutions.
for _, parentMap := range parentMaps {
// nullUnionField is nil if the field represents a key in a map.
// In that case, all the values in the map need to be checked and substituted.
if nullUnionField == nil {
for key := range parentMap {
sub, err := r.substitute(parentMap, key)
if err != nil {
return nil, err
}
// substitution not needed for this key, skip to next
if sub == nil {
continue
}
substitutions = append(substitutions, sub)
}
continue
}
// nullUnionField is not nil if it's a field within a record schema.
// In that case, we only substitute that field.
sub, err := r.substitute(parentMap, nullUnionField.Name())
if err != nil {
return nil, err
}
// substitution not needed for this key, skip to next
if sub == nil {
continue
}
substitutions = append(substitutions, sub)
}
}
return substitutions, nil
}

// substitute substitutes maps inserted by hamba/avro's Unmarshal() function
// with actual values. The input map (return by hamba/avro's Unmarshal())
// contain values encoded as maps with a single key:value pair, where
// key is the type name (e.g. {"int":1}). We want to replace all these
// maps with the actual value (e.g. 1).
func (r unionResolver) substitute(parentMap map[string]any, name string) (substitution, error) {
avroVal := parentMap[name]
if avroVal == nil {
// don't change nil values
return nil, nil //nolint:nilnil // This is the expected behavior.
}
vmap, ok := avroVal.(map[string]any)
if !ok {
// if the value is not a map, it's not a nil value
return nil, nil //nolint:nilnil // This is the expected behavior.
}
if len(vmap) != 1 {
return nil, fmt.Errorf("expected single value for %s encoded as a map, got %d elements: %w", name, len(vmap), ErrSchemaValueMismatch)
}

// this is a map with a single value, store the substitution
for _, actualVal := range vmap {
return mapSubstitution{
m: parentMap,
key: name,
val: actualVal,
}, nil
}

// we can reach this line only if we didn't return
// the substitution from the loop above
panic("substitution not returned (this is a bug in the code)")
}

// BeforeMarshal traverses the value using the schema and finds all values that
// have the Avro type Union. Those values need to be changed to a map with a
// single key that contains the name of the type. This function takes that value
Expand Down Expand Up @@ -406,6 +437,22 @@ func (r unionResolver) resolveNameForType(v any, us *avro.UnionSchema) (string,
return "", fmt.Errorf("can't resolve %v in union type %v: %w", names, us.String(), ErrSchemaValueMismatch)
}

func (r unionResolver) collectParentsForNullUnionPath(val any, p path) ([]map[string]any, error) {
var parentMaps []map[string]any
err := traverseValue(val, p, true, func(v any) {
switch v := v.(type) {
case map[string]any:
parentMaps = append(parentMaps, v)
case *map[string]any:
parentMaps = append(parentMaps, *v)
case *opencdc.StructuredData:
parentMaps = append(parentMaps, *v)
}
})

return parentMaps, err
}

func isMapUnion(schema avro.Schema) bool {
s, ok := schema.(*avro.MapSchema)
if !ok {
Expand Down Expand Up @@ -453,8 +500,6 @@ func isNullUnion(schema avro.Schema) bool {
return false
}
for _, s := range s.Types() {
// at least one of the types in the union must be a map or array for this
// to count as a map with a union type
if s.Type() == avro.Null {
return true
}
Expand Down
21 changes: 21 additions & 0 deletions schema/avro/union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,27 @@ import (
"github.com/matryer/is"
)

func TestSerde_MarshalUnmarshalNullableFields(t *testing.T) {
is := is.New(t)

sd := opencdc.StructuredData{
"appearance": map[string]interface{}{
"mode": "dark",
"color": "purple",
},
"website": nil,
}
serde, err := SerdeForType(sd)
is.NoErr(err)

bytes, err := serde.Marshal(sd)
is.NoErr(err)

var structuredData opencdc.StructuredData
err = serde.Unmarshal(bytes, &structuredData)
is.NoErr(err)
}

func TestUnionResolver(t *testing.T) {
is := is.New(t)

Expand Down

0 comments on commit d237fb3

Please sign in to comment.