Skip to content

Commit

Permalink
change Update.Set to remove nil items
Browse files Browse the repository at this point in the history
previously it was broken and returned an AWS encoding error
  • Loading branch information
guregu committed Jul 15, 2019
1 parent cc063bc commit 305af82
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 0 deletions.
37 changes: 37 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,40 @@ func isAVEqual(a, b *dynamodb.AttributeValue) bool {
}
return false
}

// isNil returns true if v is considered nil
// this is used to determine if an attribute should be set or removed
func isNil(v interface{}) bool {
if v == nil || v == "" {
return true
}

// consider v nil if it's a special encoder defined on a value type, but v is a pointer
rv := reflect.ValueOf(v)
switch v.(type) {
case Marshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalDynamo"); ok {
return true
}
}
case dynamodbattribute.Marshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalDynamoDBAttributeValue"); ok {
return true
}
}
case encoding.TextMarshaler:
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if _, ok := rv.Type().Elem().MethodByName("MarshalText"); ok {
return true
}
}
default:
// e.g. (*int)(nil)
return rv.Kind() == reflect.Ptr && rv.IsNil()
}

// non-pointers or special encoders with a pointer receiver
return false
}
33 changes: 33 additions & 0 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ var itemEncodingTests = []struct {
"A": &dynamodb.AttributeValue{S: aws.String("hello")},
},
},
{
name: "pointer (value receiver TextMarshaler)",
in: &struct {
A *textMarshaler
}{
A: new(textMarshaler),
},
out: map[string]*dynamodb.AttributeValue{
"A": &dynamodb.AttributeValue{S: aws.String("false")},
},
},
{
name: "rename",
in: struct {
Expand Down Expand Up @@ -390,9 +401,31 @@ func (tm *textMarshaler) UnmarshalText(text []byte) error {
return nil
}

type ptrTextMarshaler bool

func (tm *ptrTextMarshaler) MarshalText() ([]byte, error) {
if tm == nil {
return []byte("null"), nil
}
if *tm {
return []byte("true"), nil
}
return []byte("false"), nil
}

func (tm *ptrTextMarshaler) UnmarshalText(text []byte) error {
if string(text) == "null" {
return nil
}
*tm = string(text) == "true"
return nil
}

var (
_ Marshaler = new(customMarshaler)
_ Unmarshaler = new(customMarshaler)
_ encoding.TextMarshaler = new(textMarshaler)
_ encoding.TextUnmarshaler = new(textMarshaler)
_ encoding.TextMarshaler = new(ptrTextMarshaler)
_ encoding.TextUnmarshaler = new(ptrTextMarshaler)
)
4 changes: 4 additions & 0 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,13 @@ func (u *Update) Range(name string, value interface{}) *Update {
}

// Set changes path to the given value.
// If value is an empty string or nil, path will be removed instead.
// Paths that are reserved words are automatically escaped.
// Use single quotes to escape complex values like 'User'.'Count'.
func (u *Update) Set(path string, value interface{}) *Update {
if isNil(value) {
return u.Remove(path)
}
path, err := u.escape(path)
u.setError(err)
expr, err := u.subExpr("🝕 = ?", path, value)
Expand Down
46 changes: 46 additions & 0 deletions update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,49 @@ func TestUpdate(t *testing.T) {
t.Error("expected ConditionalCheckFailedException, not", err)
}
}

func TestUpdateNil(t *testing.T) {
if testDB == nil {
t.Skip(offlineSkipMsg)
}
table := testDB.Table(testTable)

// first, add an item to make sure there is at least one
item := widget{
UserID: 4242,
Time: time.Now().UTC(),
Msg: "delete me",
Meta: map[string]string{
"abc": "123",
},
Count: 100,
}
err := table.Put(item).Run()
if err != nil {
t.Error("unexpected error:", err)
t.FailNow()
}

// update Msg with 'nil', which should delete it
var result widget
err = table.Update("UserID", item.UserID).Range("Time", item.Time).
Set("Msg", (*textMarshaler)(nil)).
Set("Meta.'abc'", nil).
Set("Meta.'ok'", (*ptrTextMarshaler)(nil)).
Set("Count", (*int)(nil)).
Value(&result)
if err != nil {
t.Error("unexpected error:", err)
}
expected := widget{
UserID: item.UserID,
Time: item.Time,
Msg: "",
Meta: map[string]string{
"ok": "null",
},
}
if !reflect.DeepEqual(result, expected) {
t.Errorf("bad result. %+v ≠ %+v", result, expected)
}
}

0 comments on commit 305af82

Please sign in to comment.