Skip to content

Commit

Permalink
Fix Query.One + Filter behavior (#248) (#249)
Browse files Browse the repository at this point in the history
* fix Query.One + Filter behavior (#248)

* Query.One: delay unmarshaling until success (preserves old behavior)

* add some docs explaining ErrTooMany
  • Loading branch information
guregu authored Dec 17, 2024
1 parent cb20568 commit ea7f332
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 42 deletions.
75 changes: 33 additions & 42 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ func (q *Query) ConsumedCapacity(cc *ConsumedCapacity) *Query {

// One executes this query and retrieves a single result,
// unmarshaling the result to out.
// This uses the DynamoDB GetItem API when possible, otherwise Query.
// If the query returns more than one result, [ErrTooMany] may be returned. This is intended as a diagnostic for query mistakes.
// To avoid [ErrTooMany], set the [Query.Limit] to 1.
func (q *Query) One(ctx context.Context, out interface{}) error {
if q.err != nil {
return q.err
Expand Down Expand Up @@ -239,34 +242,20 @@ func (q *Query) One(ctx context.Context, out interface{}) error {
}

// If not, try a Query.
req := q.queryInput()

var res *dynamodb.QueryOutput
err := q.table.db.retry(ctx, func() error {
var err error
res, err = q.table.db.client.Query(ctx, req)
q.cc.incRequests()
if err != nil {
return err
}

switch {
case len(res.Items) == 0:
return ErrNotFound
case len(res.Items) > 1 && q.limit != 1:
return ErrTooMany
case res.LastEvaluatedKey != nil && q.searchLimit != 0:
return ErrTooMany
}

return nil
})
if err != nil {
iter := q.newIter(unmarshalItem)
var item Item
ok := iter.Next(ctx, &item)
if err := iter.Err(); err != nil {
return err
}
q.cc.add(res.ConsumedCapacity)

return unmarshalItem(res.Items[0], out)
if !ok {
return ErrNotFound
}
// Best effort: do we have any pending unused items?
if iter.hasMore() {
return ErrTooMany
}
return unmarshalItem(item, out)
}

// Count executes this request, returning the number of results.
Expand Down Expand Up @@ -314,6 +303,14 @@ func (q *Query) Count(ctx context.Context) (int, error) {
return count, nil
}

func (q *Query) newIter(unmarshal unmarshalFunc) *queryIter {
return &queryIter{
query: q,
unmarshal: unmarshal,
err: q.err,
}
}

// queryIter is the iterator for Query operations
type queryIter struct {
query *Query
Expand Down Expand Up @@ -422,6 +419,13 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool {
return itr.err == nil
}

func (itr *queryIter) hasMore() bool {
if itr.query.limit > 0 && itr.n == itr.query.limit {
return false
}
return itr.output != nil && itr.idx < len(itr.output.Items)
}

// Err returns the error encountered, if any.
// You should check this after Next is finished.
func (itr *queryIter) Err() error {
Expand Down Expand Up @@ -458,11 +462,7 @@ func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {

// All executes this request and unmarshals all results to out, which must be a pointer to a slice.
func (q *Query) All(ctx context.Context, out interface{}) error {
iter := &queryIter{
query: q,
unmarshal: unmarshalAppendTo(out),
err: q.err,
}
iter := q.newIter(unmarshalAppendTo(out))
for iter.Next(ctx, out) {
}
return iter.Err()
Expand All @@ -471,11 +471,7 @@ func (q *Query) All(ctx context.Context, out interface{}) error {
// AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice.
// This returns a PagingKey you can use with StartFrom to split up results.
func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) {
iter := &queryIter{
query: q,
unmarshal: unmarshalAppendTo(out),
err: q.err,
}
iter := q.newIter(unmarshalAppendTo(out))
for iter.Next(ctx, out) {
}
lek, err := iter.LastEvaluatedKey(ctx)
Expand All @@ -484,12 +480,7 @@ func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (P

// Iter returns a results iterator for this request.
func (q *Query) Iter() PagingIter {
iter := &queryIter{
query: q,
unmarshal: unmarshalItem,
err: q.err,
}
return iter
return q.newIter(unmarshalItem)
}

// can we use the get item API?
Expand Down
25 changes: 25 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dynamo

import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -111,6 +112,30 @@ func TestGetAllCount(t *testing.T) {
t.Errorf("bad result for get one. %v ≠ %v", one, item)
}

// trigger ErrTooMany
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).One(ctx, &one)
if !errors.Is(err, ErrTooMany) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrTooMany)
}

// suppress ErrTooMany with Limit(1)
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).Limit(1).One(ctx, &one)
if err != nil {
t.Error("unexpected error:", err)
}
if one.UserID == 0 {
t.Errorf("bad result for get one: %v", one)
}

// trigger ErrNotFound via SearchLimit + Filter + One
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Filter("Msg = ?", item.Msg).Consistent(true).SearchLimit(1).One(ctx, &one)
if !errors.Is(err, ErrNotFound) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrNotFound)
}

// GetItem + Project
one = widget{}
projected := widget{
Expand Down

0 comments on commit ea7f332

Please sign in to comment.