From 67f288f20c50450415f33ef7c255cd24a3e4b5d8 Mon Sep 17 00:00:00 2001 From: Greg Date: Sat, 4 May 2024 23:06:25 +0900 Subject: [PATCH] skip scanning struct fields for ItemMarshaler structs (#229) --- encode_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ encoding.go | 11 ++++++++-- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/encode_test.go b/encode_test.go index 81c9ec7..ab8361c 100644 --- a/encode_test.go +++ b/encode_test.go @@ -133,3 +133,57 @@ func TestMarshalItemAsymmetric(t *testing.T) { }) } } + +type isValue_Kind interface { + isValue_Kind() +} + +type myStruct struct { + OK bool + Value isValue_Kind +} + +func (ms *myStruct) MarshalDynamoItem() (map[string]*dynamodb.AttributeValue, error) { + world := "world" + return map[string]*dynamodb.AttributeValue{ + "hello": {S: &world}, + }, nil +} + +func (ms *myStruct) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + hello := item["hello"] + if hello == nil || hello.S == nil || *hello.S != "world" { + ms.OK = false + } else { + ms.OK = true + } + return nil +} + +var _ ItemMarshaler = &myStruct{} +var _ ItemUnmarshaler = &myStruct{} + +func TestMarshalItemBypass(t *testing.T) { + something := &myStruct{} + got, err := MarshalItem(something) + if err != nil { + t.Fatal(err) + } + + world := "world" + expect := map[string]*dynamodb.AttributeValue{ + "hello": {S: &world}, + } + if !reflect.DeepEqual(got, expect) { + t.Error("bad marshal. want:", expect, "got:", got) + } + + var dec myStruct + err = UnmarshalItem(got, &dec) + if err != nil { + t.Fatal(err) + } + if !dec.OK { + t.Error("bad unmarshal") + } +} diff --git a/encoding.go b/encoding.go index 2f7877e..dabfbda 100644 --- a/encoding.go +++ b/encoding.go @@ -14,8 +14,9 @@ import ( var typeCache sync.Map // unmarshalKey → *typedef type typedef struct { - decoders map[unmarshalKey]decodeFunc - fields []structField + decoders map[unmarshalKey]decodeFunc + fields []structField + marshaler bool } func newTypedef(rt reflect.Type) (*typedef, error) { @@ -27,6 +28,7 @@ func newTypedef(rt reflect.Type) (*typedef, error) { } func (def *typedef) init(rt reflect.Type) error { + rt0 := rt for rt.Kind() == reflect.Pointer { rt = rt.Elem() } @@ -37,6 +39,11 @@ func (def *typedef) init(rt reflect.Type) error { return nil } + // skip visiting struct fields if encoding will be bypassed by a custom marshaler + if shouldBypassEncodeItem(rt0) || shouldBypassEncodeItem(rt) { + return nil + } + var err error def.fields, err = structFields(rt) return err