Skip to content

Commit

Permalink
change LastEvaluatedKey signature
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Jan 27, 2024
1 parent 91db943 commit ef69352
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 48 deletions.
5 changes: 2 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ func newDB(client dynamodbiface.DynamoDBAPI, cfg aws.Config) *DB {
} else if cfg.RetryMaxAttempts > 0 {
db.retryMax = cfg.RetryMaxAttempts
}

// }

return db
Expand Down Expand Up @@ -230,15 +229,15 @@ type PagingIter interface {
Iter
// LastEvaluatedKey returns a key that can be passed to StartFrom in Query or Scan.
// Combined with SearchLimit, it is useful for paginating partial results.
LastEvaluatedKey() PagingKey
LastEvaluatedKey(context.Context) (PagingKey, error)
}

// PagingIter is an iterator of combined request results from multiple iterators running in parallel.
type ParallelIter interface {
Iter
// LastEvaluatedKeys returns each parallel segment's last evaluated key in order of segment number.
// The slice will be the same size as the number of segments, and the keys can be nil.
LastEvaluatedKeys() []PagingKey
LastEvaluatedKeys(context.Context) ([]PagingKey, error)
}

// PagingKey is a key used for splitting up partial results.
Expand Down
23 changes: 8 additions & 15 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,39 +424,32 @@ func (itr *queryIter) Err() error {
return itr.err
}

func (itr *queryIter) LastEvaluatedKey() PagingKey {
func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
if itr.output != nil {
// if we've hit the end of our results, we can use the real LEK
if itr.idx == len(itr.output.Items) {
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, nil
}

// figure out the primary keys if needed
if itr.keys == nil && itr.keyErr == nil {
ctx, _ := defaultContext() // TODO(v2): take context instead of using the default
itr.keys, itr.keyErr = itr.query.table.primaryKeys(ctx, itr.exLEK, itr.exESK, itr.query.index)
}
if itr.keyErr != nil {
// primaryKeys can fail if the credentials lack DescribeTable permissions
// in order to preserve backwards compatibility, we fall back to the old behavior and warn
// see: https://github.com/guregu/dynamo/pull/187#issuecomment-1045183901
// TODO(v2): rejigger this API.
itr.query.table.db.log("dynamo: Warning:", itr.keyErr, "Returning a later LastEvaluatedKey.")
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to determine LastEvaluatedKey in query: %w", itr.keyErr)
}

// we can't use the real LEK, so we need to infer the LEK from the last item we saw
lek, err := lekify(itr.last, itr.keys)
// unfortunately, this API can't return an error so a warning is the best we can do...
// this matches old behavior before the LEK was automatically generated
// TODO(v2): fix this.
if err != nil {
itr.query.table.db.log("dynamo: Warning:", err, "Returning a later LastEvaluatedKey.")
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to infer LastEvaluatedKey in query: %w", err)
}
return lek
return lek, nil
}
return nil
return nil, nil
}

// All executes this request and unmarshals all results to out, which must be a pointer to a slice.
Expand Down Expand Up @@ -493,7 +486,8 @@ func (q *Query) AllWithLastEvaluatedKeyContext(ctx context.Context, out interfac
}
for iter.NextWithContext(ctx, out) {
}
return iter.LastEvaluatedKey(), iter.Err()
lek, err := iter.LastEvaluatedKey(ctx)
return lek, errors.Join(iter.Err(), err)
}

// Iter returns a results iterator for this request.
Expand All @@ -503,7 +497,6 @@ func (q *Query) Iter() PagingIter {
unmarshal: unmarshalItem,
err: q.err,
}

return iter
}

Expand Down
18 changes: 15 additions & 3 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ func TestQueryPaging(t *testing.T) {
if more {
t.Error("unexpected more", more)
}
itr = table.Get("UserID", 1969).StartFrom(itr.LastEvaluatedKey()).SearchLimit(1).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Get("UserID", 1969).StartFrom(lek).SearchLimit(1).Iter()
}
}

Expand Down Expand Up @@ -235,7 +239,11 @@ func TestQueryMagicLEK(t *testing.T) {
if more {
t.Error("unexpected more", more)
}
itr = table.Get("UserID", 1970).StartFrom(itr.LastEvaluatedKey()).Limit(1).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", lek)
}
itr = table.Get("UserID", 1970).StartFrom(lek).Limit(1).Iter()
}
})

Expand Down Expand Up @@ -268,7 +276,11 @@ func TestQueryMagicLEK(t *testing.T) {
if more {
t.Error("unexpected more", more)
}
itr = table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).StartFrom(itr.LastEvaluatedKey()).Limit(1).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Get("Msg", "TestQueryMagicLEK").Index("Msg-Time-index").Filter("UserID = ?", 1970).StartFrom(lek).Limit(1).Iter()
}
})
}
Expand Down
49 changes: 26 additions & 23 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package dynamo

import (
"context"
"errors"
"fmt"
"strings"
"sync"

Expand Down Expand Up @@ -198,7 +200,8 @@ func (s *Scan) AllWithLastEvaluatedKeyContext(ctx context.Context, out interface
}
for itr.NextWithContext(ctx, out) {
}
return itr.LastEvaluatedKey(), itr.Err()
lek, err := itr.LastEvaluatedKey(ctx)
return lek, errors.Join(itr.Err(), err)
}

// AllParallel executes this request by running the given number of segments in parallel, then unmarshaling all results to out, which must be a pointer to a slice.
Expand All @@ -219,7 +222,8 @@ func (s *Scan) AllParallelWithLastEvaluatedKeys(ctx context.Context, segments in
go ps.run(ctx)
for ps.NextWithContext(ctx, out) {
}
return ps.LastEvaluatedKeys(), ps.Err()
leks, err := ps.LastEvaluatedKeys(ctx)
return leks, errors.Join(ps.Err(), err)
}

// AllParallelStartFrom executes this request by continuing parallel scans from the given LastEvaluatedKeys, then unmarshaling all results to out, which must be a pointer to a slice.
Expand All @@ -230,7 +234,8 @@ func (s *Scan) AllParallelStartFrom(ctx context.Context, keys []PagingKey, out i
go ps.run(ctx)
for ps.NextWithContext(ctx, out) {
}
return ps.LastEvaluatedKeys(), ps.Err()
leks, err := ps.LastEvaluatedKeys(ctx)
return leks, errors.Join(ps.Err(), err)
}

// Count executes this request and returns the number of items matching the scan.
Expand Down Expand Up @@ -442,49 +447,44 @@ func (itr *scanIter) Err() error {

// LastEvaluatedKey returns a key that can be used to continue this scan.
// Use with SearchLimit for best results.
func (itr *scanIter) LastEvaluatedKey() PagingKey {
func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) {
if itr.output != nil {
// if we've hit the end of our results, we can use the real LEK
if itr.idx == len(itr.output.Items) {
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, nil
}

// figure out the primary keys if needed
if itr.keys == nil && itr.keyErr == nil {
ctx, _ := defaultContext() // TODO(v2): take context instead of using the default
itr.keys, itr.keyErr = itr.scan.table.primaryKeys(ctx, itr.exLEK, itr.exESK, itr.scan.index)
}
if itr.keyErr != nil {
// primaryKeys can fail if the credentials lack DescribeTable permissions
// in order to preserve backwards compatibility, we fall back to the old behavior and warn
// see: https://github.com/guregu/dynamo/pull/187#issuecomment-1045183901
// TODO(v2): rejigger this API.
itr.scan.table.db.log("dynamo: Warning:", itr.keyErr, "Returning a later LastEvaluatedKey.")
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to determine LastEvaluatedKey in scan: %w", itr.keyErr)
}

// we can't use the real LEK, so we need to infer the LEK from the last item we saw
lek, err := lekify(itr.last, itr.keys)
// unfortunately, this API can't return an error so a warning is the best we can do...
// this matches old behavior before the LEK was automatically generated
// TODO(v2): fix this.
if err != nil {
itr.scan.table.db.log("dynamo: Warning:", err, "Returning a later LastEvaluatedKey.")
return itr.output.LastEvaluatedKey
return itr.output.LastEvaluatedKey, fmt.Errorf("dynamo: failed to infer LastEvaluatedKey in scan: %w", err)
}
return lek
return lek, nil
}
return nil
return nil, nil
}

type parallelScan struct {
iters []*scanIter
items chan Item

leks []PagingKey
cc *ConsumedCapacity
err error
mu *sync.Mutex
leks []PagingKey
lekErr error

cc *ConsumedCapacity
err error
mu *sync.Mutex

unmarshal unmarshalFunc
}
Expand Down Expand Up @@ -522,9 +522,12 @@ func (ps *parallelScan) run(ctx context.Context) {
}

if ps.leks != nil {
lek := iter.LastEvaluatedKey()
lek, err := iter.LastEvaluatedKey(ctx)
ps.mu.Lock()
ps.leks[i] = lek
if err != nil && ps.lekErr == nil {
ps.lekErr = err
}
ps.mu.Unlock()
}
}
Expand Down Expand Up @@ -582,10 +585,10 @@ func (ps *parallelScan) Err() error {
return ps.err
}

func (ps *parallelScan) LastEvaluatedKeys() []PagingKey {
func (ps *parallelScan) LastEvaluatedKeys(_ context.Context) ([]PagingKey, error) {
keys := make([]PagingKey, len(ps.leks))
ps.mu.Lock()
defer ps.mu.Unlock()
copy(keys, ps.leks)
return keys
return keys, ps.lekErr
}
24 changes: 20 additions & 4 deletions scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ func TestScanPaging(t *testing.T) {
if !more {
break
}
itr = table.Scan().StartFrom(itr.LastEvaluatedKey()).SearchLimit(1).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Scan().StartFrom(lek).SearchLimit(1).Iter()
}
for i, w := range widgets {
if w.UserID == 0 && w.Time.IsZero() {
Expand Down Expand Up @@ -159,7 +163,11 @@ func TestScanPaging(t *testing.T) {
if !more {
break
}
itr = table.Scan().SearchLimit(1).IterParallelStartFrom(ctx, itr.LastEvaluatedKeys())
leks, err := itr.LastEvaluatedKeys(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Scan().SearchLimit(1).IterParallelStartFrom(ctx, leks)
}
for i, w := range widgets {
if w.UserID == 0 && w.Time.IsZero() {
Expand Down Expand Up @@ -205,7 +213,11 @@ func TestScanMagicLEK(t *testing.T) {
if itr.Err() != nil {
t.Error("unexpected error", itr.Err())
}
itr = table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").StartFrom(itr.LastEvaluatedKey()).Limit(2).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Scan().Filter("'Msg' = ?", "TestScanMagicLEK").StartFrom(lek).Limit(2).Iter()
}
})

Expand All @@ -217,7 +229,11 @@ func TestScanMagicLEK(t *testing.T) {
if itr.Err() != nil {
t.Error("unexpected error", itr.Err())
}
itr = table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).StartFrom(itr.LastEvaluatedKey()).Limit(2).Iter()
lek, err := itr.LastEvaluatedKey(context.Background())
if err != nil {
t.Error("LEK error", err)
}
itr = table.Scan().Index("Msg-Time-index").Filter("UserID = ?", 2069).StartFrom(lek).Limit(2).Iter()
}
})

Expand Down

0 comments on commit ef69352

Please sign in to comment.