Skip to content

Commit

Permalink
fix Update.Set with empty typed strings
Browse files Browse the repository at this point in the history
see: #151
auto-omitted values not caught by isNil were resulting in a SerializationException
  • Loading branch information
guregu committed Dec 26, 2020
1 parent 20f2fb8 commit b1a1053
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 40 deletions.
37 changes: 0 additions & 37 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -712,40 +712,3 @@ func isAVEqual(a, b *dynamodb.AttributeValue) bool {
}
return false
}

// isNil returns true if v is considered nil
// this is used to determine if an attribute should be set or removed
func isNil(v interface{}) bool {
if v == nil || v == "" {
return true
}

// consider v nil if it's a special encoder defined on a value type, but v is a pointer
rv := reflect.ValueOf(v)
switch v.(type) {
case Marshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalDynamo"); ok {
return true
}
}
case dynamodbattribute.Marshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalDynamoDBAttributeValue"); ok {
return true
}
}
case encoding.TextMarshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalText"); ok {
return true
}
}
default:
// e.g. (*int)(nil)
return rv.Kind() == reflect.Ptr && rv.IsNil()
}

// non-pointers or special encoders with a pointer receiver
return false
}
14 changes: 14 additions & 0 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ var encodingTests = []struct {
"Empty": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{}},
}},
},
{
name: "textMarshaler maps",
in: struct {
M1 map[textMarshaler]bool // dont omit
}{
M1: map[textMarshaler]bool{textMarshaler(true): true},
},
out: &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{
"M1": &dynamodb.AttributeValue{M: map[string]*dynamodb.AttributeValue{
"true": &dynamodb.AttributeValue{BOOL: aws.Bool(true)},
}},
}},
},
{
name: "struct",
in: struct {
Expand Down Expand Up @@ -312,6 +325,7 @@ var itemEncodingTests = []struct {
in: struct {
OK string
EmptyStr string
EmptyStr2 customString
EmptyB []byte
EmptyL []int
EmptyM map[string]bool
Expand Down
3 changes: 3 additions & 0 deletions substitute.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (s *subber) subValue(value interface{}, flags encodeFlags) (string, error)
if err != nil {
return "", err
}
if av == nil {
return "", fmt.Errorf("invalid substitue value for '%s': %v", sub, av)
}
s.valueExpr[sub] = av
return sub, nil
}
Expand Down
10 changes: 7 additions & 3 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ func (u *Update) Range(name string, value interface{}) *Update {
// Paths that are reserved words are automatically escaped.
// Use single quotes to escape complex values like 'User'.'Count'.
func (u *Update) Set(path string, value interface{}) *Update {
if isNil(value) {
v, err := marshal(value, flagNone)
if v == nil && err == nil {
// auto-omitted value
return u.Remove(path)
}
path, err := u.escape(path)
u.setError(err)
expr, err := u.subExpr("🝕 = ?", path, value)

path, err = u.escape(path)
u.setError(err)
expr, err := u.subExpr("🝕 = ?", path, v)
u.setError(err)
u.set = append(u.set, expr)
return u
Expand Down
47 changes: 47 additions & 0 deletions update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,50 @@ func TestUpdateNil(t *testing.T) {
t.Errorf("bad result. %+v ≠ %+v", result, expected)
}
}

func TestUpdateSetAutoOmit(t *testing.T) {
if testDB == nil {
t.Skip(offlineSkipMsg)
}
table := testDB.Table(testTable)

type widget2 struct {
widget
CStr customString
SPtr *string
}

// first, add an item to make sure there is at least one
str := "delete me ptr"
item := widget2{
widget: widget{
UserID: 11111,
Time: time.Now().UTC(),
},
CStr: customString("delete me"),
SPtr: &str,
}
err := table.Put(item).Run()
if err != nil {
t.Error("unexpected error:", err)
t.FailNow()
}

// update Msg with 'nil', which should delete it
var result widget2
err = table.Update("UserID", item.UserID).Range("Time", item.Time).
Set("CStr", customString("")).
Set("SPtr", nil).
Value(&result)
if err != nil {
t.Error("unexpected error:", err)
}
expected := widget2{
widget: item.widget,
CStr: customString(""),
SPtr: nil,
}
if !reflect.DeepEqual(result, expected) {
t.Errorf("bad result. %+v ≠ %+v", result, expected)
}
}

0 comments on commit b1a1053

Please sign in to comment.