diff --git a/query.go b/query.go index ed82375..f5f7a73 100644 --- a/query.go +++ b/query.go @@ -428,6 +428,12 @@ func (itr *queryIter) Err() error { return itr.err } +func (itr *queryIter) SetError(err error) { + if itr.err == nil { + itr.err = err + } +} + func (itr *queryIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { if itr.output != nil { // if we've hit the end of our results, we can use the real LEK diff --git a/scan.go b/scan.go index 26792cf..2e8222c 100644 --- a/scan.go +++ b/scan.go @@ -434,6 +434,12 @@ func (itr *scanIter) Err() error { return itr.err } +func (itr *scanIter) SetError(err error) { + if itr.err == nil { + itr.err = err + } +} + // LastEvaluatedKey returns a key that can be used to continue this scan. // Use with SearchLimit for best results. func (itr *scanIter) LastEvaluatedKey(ctx context.Context) (PagingKey, error) { diff --git a/seq_go123.go b/seq_go123.go new file mode 100644 index 0000000..9f6c897 --- /dev/null +++ b/seq_go123.go @@ -0,0 +1,40 @@ +//go:build go1.23 + +package dynamo + +import ( + "context" + "iter" +) + +// Seq returns an item iterator compatible with Go 1.23 `for ... range` loops. +func Seq[V any](ctx context.Context, iter Iter) iter.Seq[V] { + return func(yield func(V) bool) { + item := new(V) + for iter.Next(ctx, item) { + if !yield(*item) { + break + } + item = new(V) + } + } +} + +// SeqLEK returns a LastEvaluatedKey and item iterator compatible with Go 1.23 `for ... range` loops. +func SeqLEK[V any](ctx context.Context, iter PagingIter) iter.Seq2[PagingKey, V] { + return func(yield func(PagingKey, V) bool) { + item := new(V) + for iter.Next(ctx, item) { + lek, err := iter.LastEvaluatedKey(ctx) + if err != nil { + if setter, ok := iter.(interface{ SetError(error) }); ok { + setter.SetError(err) + } + } + if !yield(lek, *item) { + break + } + item = new(V) + } + } +} diff --git a/seq_test.go b/seq_test.go new file mode 100644 index 0000000..49d0006 --- /dev/null +++ b/seq_test.go @@ -0,0 +1,75 @@ +//go:build go1.23 + +package dynamo + +import ( + "context" + "testing" + "time" +) + +func TestSeq(t *testing.T) { + if testDB == nil { + t.Skip(offlineSkipMsg) + } + ctx := context.Background() + table := testDB.Table(testTableWidgets) + + widgets := []any{ + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 00, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 10, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + widget{ + UserID: 1971, + Time: time.Date(1971, 4, 20, 0, 0, 0, 0, time.UTC), + Msg: "Seq1", + }, + } + + t.Run("prepare data", func(t *testing.T) { + if _, err := table.Batch().Write().Put(widgets...).Run(ctx); err != nil { + t.Fatal(err) + } + }) + + iter := testDB.Table(testTableWidgets).Get("UserID", 1971).Iter() + var got []*widget + var count int + for item := range Seq[*widget](ctx, iter) { + t.Log(item) + item.Count = count + got = append(got, item) + count++ + } + + if iter.Err() != nil { + t.Fatal(iter.Err()) + } + + t.Run("results match", func(t *testing.T) { + for i, item := range got { + want := widgets[i].(widget) + if !item.Time.Equal(want.Time) { + t.Error("bad result. want:", want.Time, "got:", item.Time) + } + } + }) + + t.Run("result item isolation", func(t *testing.T) { + // make sure that when mutating the result in the `for ... range` loop + // it only affects one item + t.Log("got", got) + for i, item := range got { + if item.Count != i { + t.Error("unexpected count. got:", item.Count, "want:", i) + } + } + }) +}