Skip to content

Commit

Permalink
Allow decoding to struct field of interface type (#280)
Browse files Browse the repository at this point in the history
Closes #260
Closes #275
  • Loading branch information
fxamacker authored May 30, 2021
1 parent 3240b60 commit 4a03f1c
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 2 deletions.
9 changes: 7 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
specialTypeNone specialType = iota
specialTypeUnmarshalerIface
specialTypeEmptyIface
specialTypeIface
specialTypeTag
specialTypeTime
)
Expand All @@ -57,8 +58,12 @@ func newTypeInfo(t reflect.Type) *typeInfo {
tInfo.nonPtrType = t
tInfo.nonPtrKind = k

if k == reflect.Interface && t.NumMethod() == 0 {
tInfo.spclType = specialTypeEmptyIface
if k == reflect.Interface {
if t.NumMethod() == 0 {
tInfo.spclType = specialTypeEmptyIface
} else {
tInfo.spclType = specialTypeIface
}
} else if t == typeTag {
tInfo.spclType = specialTypeTag
} else if t == typeTime {
Expand Down
6 changes: 6 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ const (
// parseToValue decodes CBOR data to value. It assumes data is well-formed,
// and does not perform bounds checking.
func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolint:gocyclo

if tInfo.spclType == specialTypeIface && !v.IsNil() {
v = v.Elem()
tInfo = getTypeInfo(v.Type())
}

// Create new value for the pointer v to point to if CBOR value is not nil/undefined.
if !d.nextCBORNil() {
for v.Kind() == reflect.Ptr {
Expand Down
224 changes: 224 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5178,3 +5178,227 @@ func TestUnmarshalInvalidTagBignum(t *testing.T) {
}
}
}

type Foo interface {
Foo() string
}

type UintFoo uint

func (f *UintFoo) Foo() string {
return fmt.Sprint(f)
}

type IntFoo int

func (f *IntFoo) Foo() string {
return fmt.Sprint(*f)
}

type ByteFoo []byte

func (f *ByteFoo) Foo() string {
return fmt.Sprint(*f)
}

type StringFoo string

func (f *StringFoo) Foo() string {
return string(*f)
}

type ArrayFoo []int

func (f *ArrayFoo) Foo() string {
return fmt.Sprint(*f)
}

type MapFoo map[int]int

func (f *MapFoo) Foo() string {
return fmt.Sprint(*f)
}

type StructFoo struct {
Value int `cbor:"1,keyasint"`
}

func (f *StructFoo) Foo() string {
return fmt.Sprint(*f)
}

type TestExample struct {
Message string `cbor:"1,keyasint"`
Foo Foo `cbor:"2,keyasint"`
}

func TestUnmarshalToInterface(t *testing.T) {

uintFoo, uintFoo123 := UintFoo(0), UintFoo(123)
intFoo, intFooNeg1 := IntFoo(0), IntFoo(-1)
byteFoo, byteFoo123 := ByteFoo(nil), ByteFoo([]byte{1, 2, 3})
stringFoo, stringFoo123 := StringFoo(""), StringFoo("123")
arrayFoo, arrayFoo123 := ArrayFoo(nil), ArrayFoo([]int{1, 2, 3})
mapFoo, mapFoo123 := MapFoo(nil), MapFoo(map[int]int{1: 1, 2: 2, 3: 3})

em, _ := EncOptions{Sort: SortCanonical}.EncMode()

testCases := []struct {
name string
data []byte
v *TestExample
unmarshalToObj *TestExample
}{
{
name: "uint",
data: hexDecode("a2016b736f6d65206d657373676502187b"), // {1: "some messge", 2: 123}
v: &TestExample{
Message: "some messge",
Foo: &uintFoo123,
},
unmarshalToObj: &TestExample{Foo: &uintFoo},
},
{
name: "int",
data: hexDecode("a2016b736f6d65206d65737367650220"), // {1: "some messge", 2: -1}
v: &TestExample{
Message: "some messge",
Foo: &intFooNeg1,
},
unmarshalToObj: &TestExample{Foo: &intFoo},
},
{
name: "bytes",
data: hexDecode("a2016b736f6d65206d65737367650243010203"), // {1: "some messge", 2: [1,2,3]}
v: &TestExample{
Message: "some messge",
Foo: &byteFoo123,
},
unmarshalToObj: &TestExample{Foo: &byteFoo},
},
{
name: "string",
data: hexDecode("a2016b736f6d65206d65737367650263313233"), // {1: "some messge", 2: "123"}
v: &TestExample{
Message: "some messge",
Foo: &stringFoo123,
},
unmarshalToObj: &TestExample{Foo: &stringFoo},
},
{
name: "array",
data: hexDecode("a2016b736f6d65206d65737367650283010203"), // {1: "some messge", 2: []int{1,2,3}}
v: &TestExample{
Message: "some messge",
Foo: &arrayFoo123,
},
unmarshalToObj: &TestExample{Foo: &arrayFoo},
},
{
name: "map",
data: hexDecode("a2016b736f6d65206d657373676502a3010102020303"), // {1: "some messge", 2: map[int]int{1:1,2:2,3:3}}
v: &TestExample{
Message: "some messge",
Foo: &mapFoo123,
},
unmarshalToObj: &TestExample{Foo: &mapFoo},
},
{
name: "struct",
data: hexDecode("a2016b736f6d65206d657373676502a1011901c8"), // {1: "some messge", 2: {1: 456}}
v: &TestExample{
Message: "some messge",
Foo: &StructFoo{Value: 456},
},
unmarshalToObj: &TestExample{Foo: &StructFoo{}},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {

data, err := em.Marshal(tc.v)
if err != nil {
t.Errorf("Marshal(%+v) returned error %v", tc.v, err)
} else if !bytes.Equal(data, tc.data) {
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.v)
}

// Unmarshal to empty interface
var einterface TestExample
if err = Unmarshal(data, &einterface); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error, want error (*UnmarshalTypeError)", data)
} else if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong type of error %T, want (*UnmarshalTypeError)", data, err)
}

// Unmarshal to interface value
err = Unmarshal(data, tc.unmarshalToObj)
if err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
} else if !reflect.DeepEqual(tc.unmarshalToObj, tc.v) {
t.Errorf("Unmarshal(0x%x) = %v, want %v", data, tc.unmarshalToObj, tc.v)
}
})
}
}

type Bar struct {
I int
}

func (b *Bar) Foo() string {
return fmt.Sprint(*b)
}

type FooStruct struct {
Foos []Foo
}

func TestUnmarshalTaggedDataToInterface(t *testing.T) {

var tags = NewTagSet()
err := tags.Add(
TagOptions{EncTag: EncTagRequired, DecTag: DecTagRequired},
reflect.TypeOf(&Bar{}),
4,
)
if err != nil {
t.Error(err)
}

v := &FooStruct{
Foos: []Foo{&Bar{1}},
}

want := hexDecode("a164466f6f7381c4a1614901") // {"Foos": [4({"I": 1})]}

em, _ := EncOptions{}.EncModeWithTags(tags)
data, err := em.Marshal(v)
if err != nil {
t.Errorf("Marshal(%+v) returned error %v", v, err)
} else if !bytes.Equal(data, want) {
t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", v, data, want)
}

dm, _ := DecOptions{}.DecModeWithTags(tags)

// Unmarshal to empty interface
var v1 Bar
if err = dm.Unmarshal(data, &v1); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error, want error (*UnmarshalTypeError)", data)
} else if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong type of error %T, want (*UnmarshalTypeError)", data, err)
}

// Unmarshal to interface value
v2 := &FooStruct{
Foos: []Foo{&Bar{}},
}
err = dm.Unmarshal(data, v2)
if err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
} else if !reflect.DeepEqual(v2, v) {
t.Errorf("Unmarshal(0x%x) = %v, want %v", data, v2, v)
}
}

0 comments on commit 4a03f1c

Please sign in to comment.