diff --git a/deserialize/deserialize.go b/deserialize/deserialize.go index 167f9de..0a38bcc 100644 --- a/deserialize/deserialize.go +++ b/deserialize/deserialize.go @@ -216,7 +216,7 @@ func MakeMapDeserializerFromReflect(options Options, typ reflect.Type) (MapRefle } noTags := tags.Empty() - reflectDeserializer, err := makeFieldDeserializerFromReflect(options.RootPath, typ, innerOptions, &noTags, placeholder, false) + reflectDeserializer, err := makeFieldDeserializerFromReflect(options.RootPath, typ, innerOptions, &noTags, placeholder, false, false) if err != nil { return nil, err @@ -296,7 +296,7 @@ func MakeKVDeserializerFromReflect(options Options, typ reflect.Type) (KVListRef } var placeholder = reflect.New(typ).Elem() noTags := tags.Empty() - wrapped, err := makeFieldDeserializerFromReflect(".", typ, innerOptions, &noTags, placeholder, false) + wrapped, err := makeFieldDeserializerFromReflect(".", typ, innerOptions, &noTags, placeholder, false, false) if err != nil { return nil, err } @@ -445,11 +445,16 @@ func deListMapReflect(typ reflect.Type, outMap map[string]any, inMap map[string] publicFieldName = &field.Name } - switch field.Type.Kind() { - case reflect.Array: + switch { + case field.Type.Kind() == reflect.Array: fallthrough - case reflect.Slice: + case field.Type.Kind() == reflect.Slice: outMap[*publicFieldName] = inMap[*publicFieldName] + case field.Type.Kind() == reflect.Struct && (tags.IsFlattened() || field.Anonymous): + err = deListMapReflect(field.Type, outMap, inMap, options) + if err != nil { + return err + } default: length := len(inMap[*publicFieldName]) switch length { @@ -632,34 +637,63 @@ func makeStructDeserializerFromReflect(path string, typ reflect.Type, options in fieldPath := fmt.Sprint(path, ".", *publicFieldName) - var fieldContentDeserializer reflectDeserializer - fieldContentDeserializer, err = makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize) - if err != nil { - return nil, err - } - fieldDeserializer := func(outPtr *reflect.Value, inMap shared.Dict) error { - // Note: maps are references, so there is no loss to passing a `map` instead of a `*map`. - // Use the `fieldName` to access the field in the record. - outReflect := outPtr.FieldByName(fieldNativeName) - - // Use the `publicFieldName` to access the field in the map. - var fieldValue shared.Value - if isPublic { - // If the field is public, we can accept external data, if provided. - var ok bool - fieldValue, ok = inMap.Lookup(*publicFieldName) - if !ok { - fieldValue = nil + var fieldDeserializer func(*reflect.Value, shared.Dict) error + if tags.IsFlattened() || field.Anonymous { + // The field is flattened either explicitly (tag `flatten`) or implicitly + // (because it's an anonymous field). In either case, the *contents* of that + // struct are pulled from *the same outer map* `inMap`. + + fieldContentDeserializer, err := makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize, true) + if err != nil { + return nil, err + } + + fieldDeserializer = func(outPtr *reflect.Value, inMap shared.Dict) error { + // Note: maps are references, so there is no loss to passing a `map` instead of a `*map`. + // Use the `fieldName` to access the field in the record. + outReflect := outPtr.FieldByName(fieldNativeName) + + err := fieldContentDeserializer(&outReflect, inMap.AsValue()) + if err != nil { + return err } + + // At this stage, the field has already been validated by using `Validator.Validate()`. + // In future versions, we may wish to add support for further validation using tags. + return nil } - err := fieldContentDeserializer(&outReflect, fieldValue) + + } else { + // The field is nested, so we'll try to move into the corresponding entry in the map. + fieldContentDeserializer, err := makeFieldDeserializerFromReflect(fieldPath, fieldType, options, &tags, selfContainer, willPreinitialize, false) if err != nil { - return err + return nil, err } - // At this stage, the field has already been validated by using `Validator.Validate()`. - // In future versions, we may wish to add support for further validation using tags. - return nil + fieldDeserializer = func(outPtr *reflect.Value, inMap shared.Dict) error { + // Note: maps are references, so there is no loss to passing a `map` instead of a `*map`. + // Use the `fieldName` to access the field in the record. + outReflect := outPtr.FieldByName(fieldNativeName) + + // Use the `publicFieldName` to access the field in the map. + var fieldValue shared.Value + if isPublic { + // If the field is public, we can accept external data, if provided. + var ok bool + fieldValue, ok = inMap.Lookup(*publicFieldName) + if !ok { + fieldValue = nil + } + } // otherwise, use the zero value for that field. + err := fieldContentDeserializer(&outReflect, fieldValue) + if err != nil { + return err + } + + // At this stage, the field has already been validated by using `Validator.Validate()`. + // In future versions, we may wish to add support for further validation using tags. + return nil + } } deserializers[field.Name] = fieldDeserializer @@ -771,8 +805,8 @@ func makeStructDeserializerFromReflect(path string, typ reflect.Type, options in } // We may now deserialize fields. - for _, fieldDeserializationData := range deserializers { - err = fieldDeserializationData(&result, inMap) + for _, fieldDeserializer := range deserializers { + err = fieldDeserializer(&result, inMap) if err != nil { return err } @@ -820,7 +854,7 @@ func makeMapDeserializerFromReflect(path string, typ reflect.Type, options inner subPath := path + "[]" subTags := tagsPkg.Empty() subTyp := typ.Elem() - contentDeserializer, err := makeFieldDeserializerFromReflect(subPath, subTyp, options, &subTags, selfContainer, initializationMetadata.willPreinitialize) + contentDeserializer, err := makeFieldDeserializerFromReflect(subPath, subTyp, options, &subTags, selfContainer, initializationMetadata.willPreinitialize, false) if err != nil { return nil, err } @@ -929,7 +963,7 @@ func makeSliceDeserializer(fieldPath string, fieldType reflect.Type, options inn // Prepare a deserializer for elements in this slice. childPreinitialized := wasPreinitialized || tags.IsPreinitialized() - elementDeserializer, err := makeFieldDeserializerFromReflect(arrayPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized) + elementDeserializer, err := makeFieldDeserializerFromReflect(arrayPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized, false) if err != nil { return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err) } @@ -1020,7 +1054,7 @@ func makePointerDeserializer(fieldPath string, fieldType reflect.Type, options i subTags := tagsPkg.Empty() subContainer := reflect.New(fieldType).Elem() childPreinitialized := wasPreinitialized || tags.IsPreinitialized() - elementDeserializer, err := makeFieldDeserializerFromReflect(ptrPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized) + elementDeserializer, err := makeFieldDeserializerFromReflect(ptrPath, fieldType.Elem(), options, &subTags, subContainer, childPreinitialized, false) if err != nil { return nil, fmt.Errorf("failed to generate a deserializer for %s\n\t * %w", fieldPath, err) } @@ -1251,15 +1285,18 @@ func makeFlatFieldDeserializer(fieldPath string, fieldType reflect.Type, options // - `typ` the dynamic type for the field being compiled; // - `tagName` the name of tags to use for field renamings, e.g. `query`; // - `tags` the table of tags for this field. -func makeFieldDeserializerFromReflect(fieldPath string, fieldType reflect.Type, options innerOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool) (reflectDeserializer, error) { - err := options.unmarshaler.Enter(fieldPath, fieldType) - if err != nil { - return nil, err //nolint:wrapcheck +func makeFieldDeserializerFromReflect(fieldPath string, fieldType reflect.Type, options innerOptions, tags *tagsPkg.Tags, container reflect.Value, wasPreinitialized bool, wasFlattened bool) (reflectDeserializer, error) { + if !wasFlattened { + err := options.unmarshaler.Enter(fieldPath, fieldType) + if err != nil { + return nil, err //nolint:wrapcheck + } + defer func() { + options.unmarshaler.Exit(fieldType) + }() } - defer func() { - options.unmarshaler.Exit(fieldType) - }() + var err error var structured reflectDeserializer switch fieldType.Kind() { diff --git a/deserialize/deserialize_reflect_test.go b/deserialize/deserialize_reflect_test.go index f1557be..8ce1c9d 100644 --- a/deserialize/deserialize_reflect_test.go +++ b/deserialize/deserialize_reflect_test.go @@ -60,6 +60,29 @@ func TestReflectMapDeserializer(t *testing.T) { assert.DeepEqual(t, &sample, out) } +func TestReflectMapEmbeddedDeserializer(t *testing.T) { + type Inner struct { + Nested string + } + type Outer struct { + Inner + String string + Int int + } + sample := Outer{ + Inner: Inner{ + Nested: "def", + }, + String: "abc", + Int: 123, + } + out, err := twoWaysReflect[Outer, Outer](t, sample) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, &sample, out) +} + func TestReflectKVDeserializer(t *testing.T) { type Test struct { String string @@ -86,3 +109,69 @@ func TestReflectKVDeserializer(t *testing.T) { assert.NilError(t, err) assert.Equal(t, *deserialized, sample) } + +// Should be useful for books, as we wouldn't have to recreate a Pagination struct for each route for example. +func TestNestedStructReflectKVDeserializer(t *testing.T) { + type NestedStruct struct { + BBB string + } + type MainStruct struct { + AAA string + NestedStruct NestedStruct `flatten:""` + } + sample := MainStruct{ + AAA: "aaa", + NestedStruct: NestedStruct{ + BBB: "bbb", + }, + } + + deserializer, err := deserialize.MakeKVDeserializerFromReflect(deserialize.Options{ + Unmarshaler: jsonPkg.Driver, + MainTagName: "json", + RootPath: "", + }, reflect.TypeOf(sample)) + assert.NilError(t, err) + + kvList := map[string][]string{} + kvList["AAA"] = []string{sample.AAA} + kvList["BBB"] = []string{sample.NestedStruct.BBB} + + deserialized := new(MainStruct) + reflectDeserialized := reflect.ValueOf(deserialized).Elem() + err = deserializer.DeserializeKVListTo(kvList, &reflectDeserialized) + assert.NilError(t, err) + assert.Equal(t, *deserialized, sample) +} + +// Not mandatory, but could be nice to have. +func TestAnonymStructReflectKVDeserializer(t *testing.T) { + type EmbeddedStruct struct { + BBB string + } + type MainStruct struct { + AAA string + EmbeddedStruct // Embedded struct are anonymous fields in reflection, flattened automatically. + } + sample := MainStruct{ + AAA: "aaa", + EmbeddedStruct: EmbeddedStruct{BBB: "bbb"}, + } + + deserializer, err := deserialize.MakeKVDeserializerFromReflect(deserialize.Options{ + Unmarshaler: jsonPkg.Driver, + MainTagName: "json", + RootPath: "", + }, reflect.TypeOf(sample)) + assert.NilError(t, err) + + kvList := map[string][]string{} + kvList["AAA"] = []string{sample.AAA} + kvList["BBB"] = []string{sample.BBB} // Embedded struct fields can be accessed like if it was at root level + + deserialized := new(MainStruct) + reflectDeserialized := reflect.ValueOf(deserialized).Elem() + err = deserializer.DeserializeKVListTo(kvList, &reflectDeserialized) + assert.NilError(t, err) + assert.Equal(t, *deserialized, sample) +} diff --git a/deserialize/deserialize_test.go b/deserialize/deserialize_test.go index 1f3b898..95b2180 100644 --- a/deserialize/deserialize_test.go +++ b/deserialize/deserialize_test.go @@ -1481,3 +1481,80 @@ func TestKVCallsInnerValidation(t *testing.T) { _, err = deserializer.DeserializeKVList(kvlist) assert.ErrorContains(t, err, "custom validation error") } + +// ------ Test that flattened structs are deserialized properly. +func TestMapDeserializerFlattened(t *testing.T) { + type Inner struct { + Left string + Right string + } + type Outer struct { + Flattened Inner `flatten:""` + Inner + Regular Inner + } + + deserializer, err := deserialize.MakeMapDeserializer[Outer](deserialize.JSONOptions("")) + assert.NilError(t, err) + + data := ` + { + "Left": "flattened_left", + "Right": "flattened_right", + "Regular": { + "Left": "regular_left", + "Right": "regular_right" + } + }` + expected := Outer{ + Flattened: Inner{ + Left: "flattened_left", + Right: "flattened_right", + }, + Inner: Inner{ + Left: "flattened_left", + Right: "flattened_right", + }, + Regular: Inner{ + Left: "regular_left", + Right: "regular_right", + }, + } + found, err := deserializer.DeserializeBytes([]byte(data)) + assert.NilError(t, err) + + assert.DeepEqual(t, *found, expected) +} + +func TestKVDeserializerFlattened(t *testing.T) { + type Inner struct { + Left string + Right string + } + type Outer struct { + Flattened Inner `flatten:""` + Inner + } + + deserializer, err := deserialize.MakeKVListDeserializer[Outer](deserialize.QueryOptions("")) + assert.NilError(t, err) + + data := make(map[string][]string) + data["Left"] = []string{"flattened_left"} + data["Right"] = []string{"flattened_right"} + + expected := Outer{ + Flattened: Inner{ + Left: "flattened_left", + Right: "flattened_right", + }, + Inner: Inner{ + Left: "flattened_left", + Right: "flattened_right", + }, + } + found, err := deserializer.DeserializeKVList(data) + assert.NilError(t, err) + + assert.DeepEqual(t, *found, expected) +} diff --git a/deserialize/tags/tags.go b/deserialize/tags/tags.go index d4d8632..b23a1c7 100644 --- a/deserialize/tags/tags.go +++ b/deserialize/tags/tags.go @@ -161,8 +161,33 @@ func (tags Tags) IsPreinitialized() bool { return ok } +// Return `true` if this field is marked as `flatten`, e.g. +// +// type Flattening struct { +// A string +// B struct { +// C string +// D string +// } // `flatten:""` +// } +// +// should deserialized from the following JSON +// +// { +// "A": "aaaaa", +// // no field B +// "C": "ccccc", +// "D": "ddddd" +// } +func (tags Tags) IsFlattened() bool { + tags.witness.Assert() + _, ok := tags.tags["flatten"] + return ok +} + // Lookup a key. func (tags Tags) Lookup(key string) ([]string, bool) { + tags.witness.Assert() result, ok := tags.tags[key] return result, ok }