Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Query.One + Filter behavior (#248) #249

Merged
merged 4 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 33 additions & 42 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ func (q *Query) ConsumedCapacity(cc *ConsumedCapacity) *Query {

// One executes this query and retrieves a single result,
// unmarshaling the result to out.
// This uses the DynamoDB GetItem API when possible, otherwise Query.
// If the query returns more than one result, [ErrTooMany] may be returned. This is intended as a diagnostic for query mistakes.
// To avoid [ErrTooMany], set the [Query.Limit] to 1.
func (q *Query) One(ctx context.Context, out interface{}) error {
if q.err != nil {
return q.err
Expand Down Expand Up @@ -239,34 +242,20 @@ func (q *Query) One(ctx context.Context, out interface{}) error {
}

// If not, try a Query.
req := q.queryInput()

var res *dynamodb.QueryOutput
err := q.table.db.retry(ctx, func() error {
var err error
res, err = q.table.db.client.Query(ctx, req)
q.cc.incRequests()
if err != nil {
return err
}

switch {
case len(res.Items) == 0:
return ErrNotFound
case len(res.Items) > 1 && q.limit != 1:
return ErrTooMany
case res.LastEvaluatedKey != nil && q.searchLimit != 0:
return ErrTooMany
}

return nil
})
if err != nil {
iter := q.newIter(unmarshalItem)
var item Item
ok := iter.Next(ctx, &item)
if err := iter.Err(); err != nil {
return err
}
q.cc.add(res.ConsumedCapacity)

return unmarshalItem(res.Items[0], out)
if !ok {
return ErrNotFound
}
// Best effort: do we have any pending unused items?
if iter.hasMore() {
return ErrTooMany
}
return unmarshalItem(item, out)
}

// Count executes this request, returning the number of results.
Expand Down Expand Up @@ -314,6 +303,14 @@ func (q *Query) Count(ctx context.Context) (int, error) {
return count, nil
}

func (q *Query) newIter(unmarshal unmarshalFunc) *queryIter {
return &queryIter{
query: q,
unmarshal: unmarshal,
err: q.err,
}
}

// queryIter is the iterator for Query operations
type queryIter struct {
query *Query
Expand Down Expand Up @@ -422,6 +419,13 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool {
return itr.err == nil
}

func (itr *queryIter) hasMore() bool {
if itr.query.limit > 0 && itr.n == itr.query.limit {
return false
}
return itr.output != nil && itr.idx < len(itr.output.Items)
}

// Err returns the error encountered, if any.
// You should check this after Next is finished.
func (itr *queryIter) Err() error {
Expand Down Expand Up @@ -458,11 +462,7 @@ func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {

// All executes this request and unmarshals all results to out, which must be a pointer to a slice.
func (q *Query) All(ctx context.Context, out interface{}) error {
iter := &queryIter{
query: q,
unmarshal: unmarshalAppendTo(out),
err: q.err,
}
iter := q.newIter(unmarshalAppendTo(out))
for iter.Next(ctx, out) {
}
return iter.Err()
Expand All @@ -471,11 +471,7 @@ func (q *Query) All(ctx context.Context, out interface{}) error {
// AllWithLastEvaluatedKey executes this request and unmarshals all results to out, which must be a pointer to a slice.
// This returns a PagingKey you can use with StartFrom to split up results.
func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (PagingKey, error) {
iter := &queryIter{
query: q,
unmarshal: unmarshalAppendTo(out),
err: q.err,
}
iter := q.newIter(unmarshalAppendTo(out))
for iter.Next(ctx, out) {
}
lek, err := iter.LastEvaluatedKey(ctx)
Expand All @@ -484,12 +480,7 @@ func (q *Query) AllWithLastEvaluatedKey(ctx context.Context, out interface{}) (P

// Iter returns a results iterator for this request.
func (q *Query) Iter() PagingIter {
iter := &queryIter{
query: q,
unmarshal: unmarshalItem,
err: q.err,
}
return iter
return q.newIter(unmarshalItem)
}

// can we use the get item API?
Expand Down
25 changes: 25 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dynamo

import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -111,6 +112,30 @@ func TestGetAllCount(t *testing.T) {
t.Errorf("bad result for get one. %v ≠ %v", one, item)
}

// trigger ErrTooMany
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).One(ctx, &one)
if !errors.Is(err, ErrTooMany) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrTooMany)
}

// suppress ErrTooMany with Limit(1)
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).Limit(1).One(ctx, &one)
if err != nil {
t.Error("unexpected error:", err)
}
if one.UserID == 0 {
t.Errorf("bad result for get one: %v", one)
}

// trigger ErrNotFound via SearchLimit + Filter + One
one = widget{}
err = table.Get("UserID", 42).Range("Time", Greater, "0").Filter("Msg = ?", item.Msg).Consistent(true).SearchLimit(1).One(ctx, &one)
if !errors.Is(err, ErrNotFound) {
t.Errorf("bad error from get one. %v ≠ %v", err, ErrNotFound)
}

// GetItem + Project
one = widget{}
projected := widget{
Expand Down
Loading