diff --git a/.github/docker-compose.yml b/.github/docker-compose.yml new file mode 100644 index 0000000..f06f00a --- /dev/null +++ b/.github/docker-compose.yml @@ -0,0 +1,8 @@ +version: '3' + +services: + dynamodb: + image: amazon/dynamodb-local:latest + ports: + - "8880:8000" + command: "-jar DynamoDBLocal.jar -sharedDb -inMemory" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..f90120b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + - name: Start DynamoDB Local + run: docker compose -f '.github/docker-compose.yml' up -d + - name: Test + run: go test -v -race -cover -coverpkg=./... ./... + env: + DYNAMO_TEST_ENDPOINT: 'http://localhost:8880' + DYNAMO_TEST_REGION: local + DYNAMO_TEST_TABLE: 'TestDB-%' + AWS_ACCESS_KEY_ID: dummy + AWS_SECRET_ACCESS_KEY: dummy + AWS_REGION: local diff --git a/README.md b/README.md index f3ac14d..e4aa6d2 100644 --- a/README.md +++ b/README.md @@ -237,38 +237,24 @@ err := db.Table("Books").Get("ID", 555).One(dynamo.AWSEncoding(&someBook)) ### Integration tests -By default, tests are run in offline mode. Create a table called `TestDB`, with a number partition key called `UserID` and a string sort key called `Time`. It also needs a Global Secondary Index called `Msg-Time-index` with a string partition key called `Msg` and a string sort key called `Time`. +By default, tests are run in offline mode. In order to run the integration tests, some environment variables need to be set. -Change the table name with the environment variable `DYNAMO_TEST_TABLE`. You must specify `DYNAMO_TEST_REGION`, setting it to the AWS region where your test table is. - - - ```bash -DYNAMO_TEST_REGION=us-west-2 go test github.com/guregu/dynamo/... -cover - ``` - -If you want to use [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html) to run local tests, specify `DYNAMO_TEST_ENDPOINT`. - - ```bash -DYNAMO_TEST_REGION=us-west-2 DYNAMO_TEST_ENDPOINT=http://localhost:8000 go test github.com/guregu/dynamo/... -cover - ``` - -Example of using [aws-cli](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Tools.CLI.html) to create a table for testing. +To run the tests against [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html): ```bash -aws dynamodb create-table \ - --table-name TestDB \ - --attribute-definitions \ - AttributeName=UserID,AttributeType=N \ - AttributeName=Time,AttributeType=S \ - AttributeName=Msg,AttributeType=S \ - --key-schema \ - AttributeName=UserID,KeyType=HASH \ - AttributeName=Time,KeyType=RANGE \ - --global-secondary-indexes \ - IndexName=Msg-Time-index,KeySchema=[{'AttributeName=Msg,KeyType=HASH'},{'AttributeName=Time,KeyType=RANGE'}],Projection={'ProjectionType=ALL'} \ - --billing-mode PAY_PER_REQUEST \ - --region us-west-2 \ - --endpoint-url http://localhost:8000 # using DynamoDB local +# Use Docker to run DynamoDB local on port 8880 +docker compose -f '.github/docker-compose.yml' up -d + +# Run the tests with a fresh table +# The tables will be created automatically +# The '%' in the table name will be replaced the current timestamp +DYNAMO_TEST_ENDPOINT='http://localhost:8880' \ + DYNAMO_TEST_REGION='local' \ + DYNAMO_TEST_TABLE='TestDB-%' \ + AWS_ACCESS_KEY_ID='dummy' \ + AWS_SECRET_ACCESS_KEY='dummy' \ + AWS_REGION='local' \ + go test -v -race ./... -cover -coverpkg=./... ``` ### License diff --git a/batch_test.go b/batch_test.go index c49cf4a..c6ad405 100644 --- a/batch_test.go +++ b/batch_test.go @@ -12,7 +12,10 @@ func TestBatchGetWrite(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table1 := testDB.Table(testTableWidgets) + table2 := testDB.Table(testTableSprockets) + tables := []Table{table1, table2} + totalBatchSize := batchSize * len(tables) ctx := context.TODO() items := make([]interface{}, batchSize) @@ -30,10 +33,17 @@ func TestBatchGetWrite(t *testing.T) { keys[i] = Keys{i, now} } + var batches []*BatchWrite + for _, table := range tables { + b := table.Batch().Write().Put(items...) + batches = append(batches, b) + } + batch1 := batches[0] + batch1.Merge(batches[1:]...) var wcc ConsumedCapacity - wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run(ctx) - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err := batch1.ConsumedCapacity(&wcc).Run(ctx) + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) @@ -43,20 +53,27 @@ func TestBatchGetWrite(t *testing.T) { } // get all - var results []widget + var gets []*BatchGet + for _, table := range tables { + b := table.Batch("UserID", "Time"). + Get(keys...). + Project("UserID", "Time"). + Consistent(true) + gets = append(gets, b) + } + var cc ConsumedCapacity - err = table.Batch("UserID", "Time"). - Get(keys...). - Project("UserID", "Time"). - Consistent(true). - ConsumedCapacity(&cc). - All(ctx, &results) + get1 := gets[0].ConsumedCapacity(&cc) + get1.Merge(gets[1:]...) + + var results []widget + err = get1.All(ctx, &results) if err != nil { t.Error("unexpected error:", err) } - if len(results) != batchSize { - t.Error("expected", batchSize, "results, got", len(results)) + if len(results) != totalBatchSize { + t.Error("expected", totalBatchSize, "results, got", len(results)) } if cc.Total == 0 { @@ -74,26 +91,31 @@ func TestBatchGetWrite(t *testing.T) { } // delete both - wrote, err = table.Batch("UserID", "Time").Write(). - Delete(keys...).Run(ctx) - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err = table1.Batch("UserID", "Time").Write(). + Delete(keys...). + DeleteInRange(table2, "UserID", "Time", keys...). + Run(ctx) + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) } // get both again - results = nil - err = table.Batch("UserID", "Time"). - Get(keys...). - Consistent(true). - All(ctx, &results) - if err != ErrNotFound { - t.Error("expected ErrNotFound, got", err) - } - if len(results) != 0 { - t.Error("expected 0 results, got", len(results)) + { + var results []widget + err = table1.Batch("UserID", "Time"). + Get(keys...). + FromRange(table2, "UserID", "Time", keys...). + Consistent(true). + All(ctx, &results) + if err != ErrNotFound { + t.Error("expected ErrNotFound, got", err) + } + if len(results) != 0 { + t.Error("expected 0 results, got", len(results)) + } } } @@ -101,7 +123,7 @@ func TestBatchGetEmptySets(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() now := time.Now().UnixNano() / 1000000000 @@ -153,8 +175,8 @@ func TestBatchGetEmptySets(t *testing.T) { } func TestBatchEmptyInput(t *testing.T) { + table := testDB.Table(testTableWidgets) ctx := context.TODO() - table := testDB.Table(testTable) var out []any err := table.Batch("UserID", "Time").Get().All(ctx, &out) if err != ErrNoInput { diff --git a/batchget.go b/batchget.go index e5b8b03..09a035c 100644 --- a/batchget.go +++ b/batchget.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" @@ -43,12 +44,12 @@ func (table Table) Batch(hashAndRangeKeyName ...string) Batch { // BatchGet is a BatchGetItem operation. type BatchGet struct { - batch Batch - reqs []*Query - projection string - consistent bool + batch Batch + reqs []*Query + projections map[string][]string // table → paths + projection []string // default paths + consistent bool - subber err error cc *ConsumedCapacity } @@ -63,46 +64,106 @@ func (b Batch) Get(keys ...Keyed) *BatchGet { batch: b, err: b.err, } - bg.add(keys) - return bg + return bg.And(keys...) } -// And adds more keys to be gotten. +// And adds more keys to be gotten from the default table. +// To get items from other tables, use [BatchGet.From] or [BatchGet.FromRange]. func (bg *BatchGet) And(keys ...Keyed) *BatchGet { - bg.add(keys) - return bg + return bg.add(bg.batch.table, bg.batch.hashKey, bg.batch.rangeKey, keys...) +} + +// From adds more keys to be gotten from the given table. +// The given table's primary key must be a hash key (partition key) only. +// For tables with a range key (sort key) primary key, use [BatchGet.FromRange]. +func (bg *BatchGet) From(table Table, hashKey string, keys ...Keyed) *BatchGet { + return bg.add(table, hashKey, "", keys...) +} + +// FromRange adds more keys to be gotten from the given table. +// For tables without a range key (sort key) primary key, use [BatchGet.From]. +func (bg *BatchGet) FromRange(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchGet { + return bg.add(table, hashKey, rangeKey, keys...) } -func (bg *BatchGet) add(keys []Keyed) { +func (bg *BatchGet) add(table Table, hashKey string, rangeKey string, keys ...Keyed) *BatchGet { for _, key := range keys { if key == nil { bg.setError(errors.New("dynamo: batch: the Keyed interface must not be nil")) break } - get := bg.batch.table.Get(bg.batch.hashKey, key.HashKey()) - if rk := key.RangeKey(); bg.batch.rangeKey != "" && rk != nil { - get.Range(bg.batch.rangeKey, Equal, rk) + get := table.Get(hashKey, key.HashKey()) + if rk := key.RangeKey(); rangeKey != "" && rk != nil { + get.Range(rangeKey, Equal, rk) bg.setError(get.err) } bg.reqs = append(bg.reqs, get) } + return bg } // Project limits the result attributes to the given paths. +// This will apply to all tables, but can be overriden by [BatchGet.ProjectTable] to set specific per-table projections. func (bg *BatchGet) Project(paths ...string) *BatchGet { - var expr string - for i, p := range paths { - if i != 0 { - expr += ", " + bg.projection = paths + return bg +} + +// Project limits the result attributes to the given paths for the given table. +func (bg *BatchGet) ProjectTable(table Table, paths ...string) *BatchGet { + return bg.project(table.Name(), paths...) +} + +func (bg *BatchGet) project(table string, paths ...string) *BatchGet { + if bg.projections == nil { + bg.projections = make(map[string][]string) + } + bg.projections[table] = paths + return bg +} + +func (bg *BatchGet) projectionFor(table string) []string { + if proj := bg.projections[table]; proj != nil { + return proj + } + if bg.projection != nil { + return bg.projection + } + return nil +} + +// Merge copies operations and settings from src to this batch get. +func (bg *BatchGet) Merge(srcs ...*BatchGet) *BatchGet { + for _, src := range srcs { + bg.reqs = append(bg.reqs, src.reqs...) + bg.consistent = bg.consistent || src.consistent + this := bg.batch.table.Name() + for table, proj := range src.projections { + if this == table { + continue + } + bg.mergeProjection(table, proj) + } + if len(src.projection) > 0 { + if that := src.batch.table.Name(); that != this { + bg.mergeProjection(that, src.projection) + } } - name, err := bg.escape(p) - bg.setError(err) - expr += name } - bg.projection = expr return bg } +func (bg *BatchGet) mergeProjection(table string, proj []string) { + current := bg.projections[table] + merged := current + for _, path := range proj { + if !slices.Contains(current, path) { + merged = append(merged, path) + } + } + bg.project(table, merged...) +} + // Consistent will, if on is true, make this batch use a strongly consistent read. // Reads are eventually consistent by default. // Strongly consistent reads are more resource-heavy than eventually consistent reads. @@ -119,7 +180,7 @@ func (bg *BatchGet) ConsumedCapacity(cc *ConsumedCapacity) *BatchGet { // All executes this request and unmarshals all results to out, which must be a pointer to a slice. func (bg *BatchGet) All(ctx context.Context, out interface{}) error { - iter := newBGIter(bg, unmarshalAppendTo(out), bg.err) + iter := newBGIter(bg, unmarshalAppendTo(out), nil, bg.err) for iter.Next(ctx, out) { } return iter.Err() @@ -127,7 +188,13 @@ func (bg *BatchGet) All(ctx context.Context, out interface{}) error { // Iter returns a results iterator for this batch. func (bg *BatchGet) Iter() Iter { - return newBGIter(bg, unmarshalItem, bg.err) + return newBGIter(bg, unmarshalItem, nil, bg.err) +} + +// IterWithTable is like [BatchGet.Iter], but will update the value pointed by tablePtr after each iteration. +// This can be useful when getting from multiple tables to determine which table the latest item came from. +func (bg *BatchGet) IterWithTable(tablePtr *string) Iter { + return newBGIter(bg, unmarshalItem, tablePtr, bg.err) } func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { @@ -140,12 +207,12 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { } in := &dynamodb.BatchGetItemInput{ - RequestItems: make(map[string]types.KeysAndAttributes, 1), + RequestItems: make(map[string]types.KeysAndAttributes), } - if bg.projection != "" { - for _, get := range bg.reqs[start:end] { - get.Project(get.projection) + for _, get := range bg.reqs[start:end] { + if proj := bg.projectionFor(get.table.Name()); proj != nil { + get.Project(proj...) bg.setError(get.err) } } @@ -153,22 +220,19 @@ func (bg *BatchGet) input(start int) *dynamodb.BatchGetItemInput { in.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes } - var kas *types.KeysAndAttributes for _, get := range bg.reqs[start:end] { - if kas == nil { + table := get.table.Name() + kas, ok := in.RequestItems[table] + if !ok { kas = get.keysAndAttribs() + if bg.consistent { + kas.ConsistentRead = &bg.consistent + } + in.RequestItems[table] = kas continue } kas.Keys = append(kas.Keys, get.keys()) } - if bg.projection != "" { - kas.ProjectionExpression = &bg.projection - kas.ExpressionAttributeNames = bg.nameExpr - } - if bg.consistent { - kas.ConsistentRead = &bg.consistent - } - in.RequestItems[bg.batch.table.Name()] = *kas return in } @@ -181,8 +245,10 @@ func (bg *BatchGet) setError(err error) { // bgIter is the iterator for Batch Get operations type bgIter struct { bg *BatchGet + track *string // table out value input *dynamodb.BatchGetItemInput output *dynamodb.BatchGetItemOutput + got []batchGot err error idx int total int @@ -191,13 +257,19 @@ type bgIter struct { unmarshal unmarshalFunc } -func newBGIter(bg *BatchGet, fn unmarshalFunc, err error) *bgIter { +type batchGot struct { + table string + item Item +} + +func newBGIter(bg *BatchGet, fn unmarshalFunc, track *string, err error) *bgIter { if err == nil && len(bg.reqs) == 0 { err = ErrNoInput } iter := &bgIter{ bg: bg, + track: track, err: err, backoff: backoff.NewExponentialBackOff(), unmarshal: fn, @@ -217,16 +289,14 @@ func (itr *bgIter) Next(ctx context.Context, out interface{}) bool { return false } - tableName := itr.bg.batch.table.Name() - redo: // can we use results we already have? - if itr.output != nil && itr.idx < len(itr.output.Responses[tableName]) { - items := itr.output.Responses[tableName] - item := items[itr.idx] - itr.err = itr.unmarshal(item, out) + if itr.output != nil && itr.idx < len(itr.got) { + got := itr.got[itr.idx] + itr.err = itr.unmarshal(got.item, out) itr.idx++ itr.total++ + itr.trackTable(got.table) return itr.err == nil } @@ -235,16 +305,15 @@ redo: itr.input = itr.bg.input(itr.processed) } - if itr.output != nil && itr.idx >= len(itr.output.Responses[tableName]) { - var unprocessed int - + if itr.output != nil && itr.idx >= len(itr.got) { + for _, req := range itr.input.RequestItems { + itr.processed += len(req.Keys) + } if itr.output.UnprocessedKeys != nil { - _, ok := itr.output.UnprocessedKeys[tableName] - if ok { - unprocessed = len(itr.output.UnprocessedKeys[tableName].Keys) + for _, keys := range itr.output.UnprocessedKeys { + itr.processed -= len(keys.Keys) } } - itr.processed += len(itr.input.RequestItems[tableName].Keys) - unprocessed // have we exhausted all results? if len(itr.output.UnprocessedKeys) == 0 { // yes, try to get next inner batch of 100 items @@ -282,10 +351,27 @@ redo: } } + itr.got = itr.got[:0] + for table, resp := range itr.output.Responses { + for _, item := range resp { + itr.got = append(itr.got, batchGot{ + table: table, + item: item, + }) + } + } + // we've got unprocessed results, marshal one goto redo } +func (itr *bgIter) trackTable(next string) { + if itr.track == nil { + return + } + *itr.track = next +} + // Err returns the error encountered, if any. // You should check this after Next is finished. func (itr *bgIter) Err() error { diff --git a/batchwrite.go b/batchwrite.go index 93e84ba..adf5021 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -16,11 +16,16 @@ const maxWriteOps = 25 // BatchWrite is a BatchWriteItem operation. type BatchWrite struct { batch Batch - ops []types.WriteRequest + ops []batchWrite err error cc *ConsumedCapacity } +type batchWrite struct { + table string + op types.WriteRequest +} + // Write creates a new batch write request, to which // puts and deletes can be added. func (b Batch) Write() *BatchWrite { @@ -30,29 +35,71 @@ func (b Batch) Write() *BatchWrite { } } -// Put adds put operations for items to this batch. +// Put adds put operations for items to this batch using the default table. func (bw *BatchWrite) Put(items ...interface{}) *BatchWrite { + return bw.PutIn(bw.batch.table, items...) +} + +// PutIn adds put operations for items to this batch using the given table. +// This can be useful for writing to multiple different tables. +func (bw *BatchWrite) PutIn(table Table, items ...interface{}) *BatchWrite { + name := table.Name() for _, item := range items { encoded, err := marshalItem(item) bw.setError(err) - bw.ops = append(bw.ops, types.WriteRequest{PutRequest: &types.PutRequest{ - Item: encoded, - }}) + bw.ops = append(bw.ops, batchWrite{ + table: name, + op: types.WriteRequest{PutRequest: &types.PutRequest{ + Item: encoded, + }}, + }) } return bw } -// Delete adds delete operations for the given keys to this batch. +// Delete adds delete operations for the given keys to this batch, using the default table. func (bw *BatchWrite) Delete(keys ...Keyed) *BatchWrite { + return bw.deleteIn(bw.batch.table, bw.batch.hashKey, bw.batch.rangeKey, keys...) +} + +// DeleteIn adds delete operations for the given keys to this batch, using the given table. +// hashKey must be the name of the primary key hash (partition) attribute. +// This function is for tables with a hash key (partition key) only. +// For tables including a range key (sort key) primary key, use [BatchWrite.DeleteInRange] instead. +func (bw *BatchWrite) DeleteIn(table Table, hashKey string, keys ...Keyed) *BatchWrite { + return bw.deleteIn(table, hashKey, "", keys...) +} + +// DeleteInRange adds delete operations for the given keys to this batch, using the given table. +// hashKey must be the name of the primary key hash (parition) attribute, rangeKey must be the name of the primary key range (sort) attribute. +// This function is for tables with a hash key (partition key) and range key (sort key). +// For tables without a range key primary key, use [BatchWrite.DeleteIn] instead. +func (bw *BatchWrite) DeleteInRange(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchWrite { + return bw.deleteIn(table, hashKey, rangeKey, keys...) +} + +func (bw *BatchWrite) deleteIn(table Table, hashKey, rangeKey string, keys ...Keyed) *BatchWrite { + name := table.Name() for _, key := range keys { - del := bw.batch.table.Delete(bw.batch.hashKey, key.HashKey()) - if rk := key.RangeKey(); bw.batch.rangeKey != "" && rk != nil { - del.Range(bw.batch.rangeKey, rk) + del := table.Delete(hashKey, key.HashKey()) + if rk := key.RangeKey(); rangeKey != "" && rk != nil { + del.Range(rangeKey, rk) bw.setError(del.err) } - bw.ops = append(bw.ops, types.WriteRequest{DeleteRequest: &types.DeleteRequest{ - Key: del.key(), - }}) + bw.ops = append(bw.ops, batchWrite{ + table: name, + op: types.WriteRequest{DeleteRequest: &types.DeleteRequest{ + Key: del.key(), + }}, + }) + } + return bw +} + +// Merge copies operations from src to this batch. +func (bw *BatchWrite) Merge(srcs ...*BatchWrite) *BatchWrite { + for _, src := range srcs { + bw.ops = append(bw.ops, src.ops...) } return bw } @@ -103,12 +150,21 @@ func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { } } - unprocessed := res.UnprocessedItems[bw.batch.table.Name()] - wrote += len(ops) - len(unprocessed) - if len(unprocessed) == 0 { + wrote += len(ops) + if len(res.UnprocessedItems) == 0 { break } - ops = unprocessed + + ops = ops[:0] + for tableName, unprocessed := range res.UnprocessedItems { + wrote -= len(unprocessed) + for _, op := range unprocessed { + ops = append(ops, batchWrite{ + table: tableName, + op: op, + }) + } + } // need to sleep when re-requesting, per spec if err := time.SleepWithContext(ctx, boff.NextBackOff()); err != nil { @@ -121,11 +177,13 @@ func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { return wrote, nil } -func (bw *BatchWrite) input(ops []types.WriteRequest) *dynamodb.BatchWriteItemInput { +func (bw *BatchWrite) input(ops []batchWrite) *dynamodb.BatchWriteItemInput { + items := make(map[string][]types.WriteRequest) + for _, op := range ops { + items[op.table] = append(items[op.table], op.op) + } input := &dynamodb.BatchWriteItemInput{ - RequestItems: map[string][]types.WriteRequest{ - bw.batch.table.Name(): ops, - }, + RequestItems: items, } if bw.cc != nil { input.ReturnConsumedCapacity = types.ReturnConsumedCapacityIndexes diff --git a/db_test.go b/db_test.go index a83f1f1..58ecbd9 100644 --- a/db_test.go +++ b/db_test.go @@ -2,40 +2,66 @@ package dynamo import ( "context" + "errors" + "fmt" "log" "os" + "strconv" + "strings" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/smithy-go" ) var ( - testDB *DB - testTable = "TestDB" + testDB *DB + testTableWidgets = "TestDB" + testTableSprockets = "TestDB-Sprockets" ) var dummyCreds = credentials.NewStaticCredentialsProvider("dummy", "dummy", "") const offlineSkipMsg = "DYNAMO_TEST_REGION not set" -func init() { - // os.Setenv("DYNAMO_TEST_REGION", "us-west-2") - if region := os.Getenv("DYNAMO_TEST_REGION"); region != "" { - var endpoint aws.EndpointResolverWithOptions - if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { - endpoint = aws.EndpointResolverWithOptionsFunc( +// widget is the data structure used for integration tests +type widget struct { + UserID int `dynamo:",hash"` + Time time.Time `dynamo:",range" index:"Msg-Time-index,range"` + Msg string `index:"Msg-Time-index,hash"` + Count int + Meta map[string]string + StrPtr *string `dynamo:",allowempty"` +} + +func TestMain(m *testing.M) { + var endpoint, region *string + if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { + endpoint = &dte + } + if dtr := os.Getenv("DYNAMO_TEST_REGION"); dtr != "" { + region = &dtr + } + if endpoint != nil && region == nil { + dtr := "local" + region = &dtr + } + if region != nil { + var resolv aws.EndpointResolverWithOptions + if endpoint != nil { + resolv = aws.EndpointResolverWithOptionsFunc( func(service, region string, options ...interface{}) (aws.Endpoint, error) { - return aws.Endpoint{URL: dte}, nil + return aws.Endpoint{URL: *endpoint}, nil }, ) } cfg, err := config.LoadDefaultConfig( context.Background(), - config.WithRegion(region), - config.WithEndpointResolverWithOptions(endpoint), + config.WithRegion(*region), + config.WithEndpointResolverWithOptions(resolv), config.WithRetryer(nil), ) if err != nil { @@ -43,19 +69,73 @@ func init() { } testDB = New(cfg) } + + timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10) + var offline bool if table := os.Getenv("DYNAMO_TEST_TABLE"); table != "" { - testTable = table + offline = false + // Test-% --> Test-1707708680863 + table = strings.ReplaceAll(table, "%", timestamp) + testTableWidgets = table + } + if table := os.Getenv("DYNAMO_TEST_TABLE2"); table != "" { + table = strings.ReplaceAll(table, "%", timestamp) + testTableSprockets = table + } else if !offline { + testTableSprockets = testTableWidgets + "-Sprockets" + } + + if !offline && testTableWidgets == testTableSprockets { + panic(fmt.Sprintf("DYNAMO_TEST_TABLE must not equal DYNAMO_TEST_TABLE2. got DYNAMO_TEST_TABLE=%q and DYNAMO_TEST_TABLE2=%q", + testTableWidgets, testTableSprockets)) + } + + var shouldCreate bool + switch os.Getenv("DYNAMO_TEST_CREATE_TABLE") { + case "1", "true", "yes": + shouldCreate = true + case "0", "false", "no": + shouldCreate = false + default: + shouldCreate = endpoint != nil + } + ctx := context.Background() + var created []Table + if testDB != nil { + for _, name := range []string{testTableWidgets, testTableSprockets} { + table := testDB.Table(name) + log.Println("Checking test table:", name) + _, err := table.Describe().Run(ctx) + switch { + case isTableNotExistsErr(err) && shouldCreate: + log.Println("Creating test table:", name) + if err := testDB.CreateTable(name, widget{}).Run(ctx); err != nil { + panic(err) + } + created = append(created, testDB.Table(name)) + case err != nil: + panic(err) + } + } + } + + code := m.Run() + defer os.Exit(code) + + for _, table := range created { + log.Println("Deleting test table:", table.Name()) + if err := table.DeleteTable().Run(ctx); err != nil { + log.Println("Error deleting test table:", table.Name(), err) + } } } -// widget is the data structure used for integration tests -type widget struct { - UserID int `dynamo:",hash"` - Time time.Time `dynamo:",range"` - Msg string - Count int - Meta map[string]string - StrPtr *string `dynamo:",allowempty"` +func isTableNotExistsErr(err error) bool { + var aerr smithy.APIError + if errors.As(err, &aerr) { + return aerr.ErrorCode() == "ResourceNotFoundException" + } + return false } func TestListTables(t *testing.T) { @@ -71,13 +151,13 @@ func TestListTables(t *testing.T) { found := false for _, t := range tables { - if t == testTable { + if t == testTableWidgets { found = true break } } if !found { - t.Error("couldn't find testTable", testTable, "in:", tables) + t.Error("couldn't find testTable", testTableWidgets, "in:", tables) } } diff --git a/decode.go b/decode.go index 03fbd03..7586c69 100644 --- a/decode.go +++ b/decode.go @@ -98,14 +98,14 @@ func unmarshalAppendTo(out interface{}) func(item Item, out interface{}) error { /* Like: - member := new(T) return func(item, ...) { + member := new(T) decode(item, member) *slice = append(*slice, *member) } */ - member := reflect.New(membert) // *T of *[]T - return func(item Item, _ any) error { + return func(item map[string]types.AttributeValue, _ any) error { + member := reflect.New(membert) // *T of *[]T if err := plan.decodeItem(item, member); err != nil { return err } diff --git a/decode_test.go b/decode_test.go index df91fb0..7e5b7b0 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1,7 +1,9 @@ package dynamo import ( + "maps" "reflect" + "strconv" "testing" "time" @@ -67,7 +69,7 @@ func TestUnmarshalAsymmetric(t *testing.T) { func TestUnmarshalAppend(t *testing.T) { var results []struct { - User int `dynamo:"UserID"` + User *int `dynamo:"UserID"` Page int Limit uint Null interface{} @@ -83,16 +85,22 @@ func TestUnmarshalAppend(t *testing.T) { "Null": &types.AttributeValueMemberNULL{Value: null}, } - for range [15]struct{}{} { - err := unmarshalAppend(item, &results) + do := unmarshalAppendTo(&results) + + for i := range [15]struct{}{} { + item2 := maps.Clone(item) + id := 12345 + i + idstr := strconv.Itoa(id) + item2["UserID"] = &types.AttributeValueMemberN{Value: idstr} + err := do(item2, &results) if err != nil { t.Fatal(err) } } - for _, h := range results { - if h.User != 12345 || h.Page != 5 || h.Limit != 20 || h.Null != nil { - t.Error("invalid hit", h) + for i, h := range results { + if *h.User != 12345+i || h.Page != 5 || h.Limit != 20 || h.Null != nil { + t.Error("invalid hit", h, *h.User) } } diff --git a/delete_test.go b/delete_test.go index 6858678..feae4c2 100644 --- a/delete_test.go +++ b/delete_test.go @@ -11,8 +11,8 @@ func TestDelete(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) ctx := context.TODO() + table := testDB.Table(testTableWidgets) // first, add an item to delete later item := widget{ diff --git a/describetable_test.go b/describetable_test.go index 894b951..798a94b 100644 --- a/describetable_test.go +++ b/describetable_test.go @@ -9,7 +9,7 @@ func TestDescribeTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.Describe().Run(context.TODO()) if err != nil { @@ -17,8 +17,8 @@ func TestDescribeTable(t *testing.T) { return } - if desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.HashKey != "UserID" || desc.RangeKey != "Time" { t.Error("bad keys:", desc.HashKey, desc.RangeKey) diff --git a/encode.go b/encode.go index 58da7eb..db1d775 100644 --- a/encode.go +++ b/encode.go @@ -48,7 +48,11 @@ func marshal(v interface{}, flags encodeFlags) (types.AttributeValue, error) { } rt := rv.Type() - enc, err := encodeType(rt, flags) + def, err := typedefOf(rt) + if err != nil { + return nil, err + } + enc, err := def.encodeType(rt, flags) if err != nil { return nil, err } @@ -106,7 +110,7 @@ type isZeroer interface { IsZero() bool } -func isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool { +func (def *typedef) isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool { if rt.Implements(rtypeIsZeroer) { return isZeroIface(rt, func(v isZeroer) bool { return v.IsZero() @@ -131,10 +135,10 @@ func isZeroFunc(rt reflect.Type) func(rv reflect.Value) bool { return isNil case reflect.Array: - return isZeroArray(rt) + return def.isZeroArray(rt) case reflect.Struct: - return isZeroStruct(rt) + return def.isZeroStruct(rt) } return isZeroValue @@ -160,13 +164,13 @@ func isZeroIface[T any](rt reflect.Type, isZero func(v T) bool) func(rv reflect. } } -func isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool { - fields, err := structFields(rt) +func (def *typedef) isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool { + fields, err := def.structFields(rt, false) if err != nil { return nil } return func(rv reflect.Value) bool { - for _, info := range fields { + for _, info := range *fields { if info.isZero == nil { continue } @@ -184,8 +188,8 @@ func isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool { } } -func isZeroArray(rt reflect.Type) func(reflect.Value) bool { - elemIsZero := isZeroFunc(rt.Elem()) +func (def *typedef) isZeroArray(rt reflect.Type) func(reflect.Value) bool { + elemIsZero := def.isZeroFunc(rt.Elem()) return func(rv reflect.Value) bool { for i := 0; i < rv.Len(); i++ { if !elemIsZero(rv.Index(i)) { diff --git a/encode_test.go b/encode_test.go index 3a16e3b..fe95dc2 100644 --- a/encode_test.go +++ b/encode_test.go @@ -139,3 +139,127 @@ func TestMarshalItemAsymmetric(t *testing.T) { }) } } + +type isValue_Kind interface { + isValue_Kind() +} + +type myStruct struct { + OK bool + Value isValue_Kind +} + +func (ms *myStruct) MarshalDynamoItem() (map[string]types.AttributeValue, error) { + world := "world" + return map[string]types.AttributeValue{ + "hello": &types.AttributeValueMemberS{Value: world}, + }, nil +} + +func (ms *myStruct) UnmarshalDynamoItem(item map[string]types.AttributeValue) error { + hello := item["hello"] + if h, ok := hello.(*types.AttributeValueMemberS); ok && h.Value == "world" { + ms.OK = true + } else { + ms.OK = false + } + return nil +} + +var _ ItemMarshaler = &myStruct{} +var _ ItemUnmarshaler = &myStruct{} + +func TestMarshalItemBypass(t *testing.T) { + something := &myStruct{} + got, err := MarshalItem(something) + if err != nil { + t.Fatal(err) + } + + world := "world" + expect := map[string]types.AttributeValue{ + "hello": &types.AttributeValueMemberS{Value: world}, + } + if !reflect.DeepEqual(got, expect) { + t.Error("bad marshal. want:", expect, "got:", got) + } + + var dec myStruct + err = UnmarshalItem(got, &dec) + if err != nil { + t.Fatal(err) + } + if !dec.OK { + t.Error("bad unmarshal") + } +} + +func TestMarshalRecursive(t *testing.T) { + t.SkipNow() + + type Person struct { + Spouse *Person + Children []Person + Name string + } + type Friend struct { + ID int + Person Person + Nickname string + } + children := []Person{ + {Name: "Bobby"}, + } + + hank := Person{ + Spouse: &Person{ + Name: "Peggy", + Children: children, + }, + Children: children, + Name: "Hank", + } + + t.Run("self-recursive", func(t *testing.T) { + + want := map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Hank"}, + "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Peggy"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }, + }, + }}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }}, + } + + got, err := MarshalItem(hank) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Error("bad", got) + } + }) + + t.Run("field is recursive", func(t *testing.T) { + friend := Friend{ + Person: hank, + Nickname: "H-Dawg", + } + got, err := MarshalItem(friend) + if err != nil { + t.Fatal(err) + } + t.Fatal(got) + }) +} diff --git a/encodefunc.go b/encodefunc.go index e6cbf91..fad4c2f 100644 --- a/encodefunc.go +++ b/encodefunc.go @@ -13,10 +13,9 @@ import ( type encodeFunc func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) -func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { try := rt for { - // deref := func() switch try { case rtypeAttrB: return encode2(func(av types.AttributeValue, _ encodeFlags) (types.AttributeValue, error) { @@ -123,7 +122,7 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { switch rt.Kind() { case reflect.Pointer: - return encodePtr(rt, flags) + return def.encodePtr(rt, flags) // BOOL case reflect.Bool: @@ -153,7 +152,7 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { return encodeSet(rt, flags) } // lists (L) - return encodeList(rt, flags) + return def.encodeList(rt, flags) case reflect.Map: // sets (NS, SS, BS) @@ -161,22 +160,22 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { return encodeSet(rt, flags) } // M - return encodeMapM(rt, flags) + return def.encodeMapM(rt, flags) // M case reflect.Struct: - return encodeStruct(rt) + return def.encodeStruct(rt) case reflect.Interface: if rt.NumMethod() == 0 { - return encodeAny, nil + return def.encodeAny, nil } } return nil, fmt.Errorf("dynamo marshal: unsupported type %s", rt.String()) } -func encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { - elem, err := encodeType(rt.Elem(), flags) +func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { + elem, err := def.encodeType(rt.Elem(), flags) if err != nil { return nil, err } @@ -279,13 +278,23 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { } } -func encodeStruct(rt reflect.Type) (encodeFunc, error) { - fields, err := structFields(rt) +func (def *typedef) encodeStruct(rt reflect.Type) (encodeFunc, error) { + var fields *[]structField + var err error + if def.sameAsRoot(rt) { + fields, err = def.structFields(rt, false) + } else { + var subdef *typedef + subdef, err = typedefOf(rt) + if subdef != nil { + fields = &subdef.fields + } + } if err != nil { return nil, err } return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { - item, err := encodeItem(fields, rv) + item, err := encodeItem(*fields, rv) if err != nil { return nil, err } @@ -372,7 +381,7 @@ func encodeSliceBS(rv reflect.Value, flags encodeFlags) (types.AttributeValue, e return &types.AttributeValueMemberBS{Value: bs}, nil } -func encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { keyString := encodeMapKeyFunc(rt) if keyString == nil { return nil, fmt.Errorf("dynamo marshal: map key type must be string or encoding.TextMarshaler, have %v", rt) @@ -388,7 +397,7 @@ func encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { subflags |= flagOmitEmpty } - valueEnc, err := encodeType(rt.Elem(), subflags) + valueEnc, err := def.encodeType(rt.Elem(), subflags) if err != nil { return nil, err } @@ -585,7 +594,7 @@ func encodeSet(rt /* []T | map[T]bool | map[T]struct{} */ reflect.Type, flags en return nil, fmt.Errorf("dynamo: marshal: invalid type for set %s", rt.String()) } -func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { // lists CAN be empty subflags := flagNone if flags&flagOmitEmptyElem == 0 { @@ -599,7 +608,7 @@ func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { subflags |= flagAllowEmptyElem } - valueEnc, err := encodeType(rt.Elem(), subflags) + valueEnc, err := def.encodeType(rt.Elem(), subflags) if err != nil { return nil, err } @@ -629,14 +638,14 @@ func encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { }, nil } -func encodeAny(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { +func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { if !rv.CanInterface() || rv.IsNil() { if flags&flagNull != 0 { return nullAV, nil } return nil, nil } - enc, err := encodeType(rv.Elem().Type(), flags) + enc, err := def.encodeType(rv.Elem().Type(), flags) if err != nil { return nil, err } diff --git a/encoding.go b/encoding.go index 605eab5..599dd08 100644 --- a/encoding.go +++ b/encoding.go @@ -15,17 +15,21 @@ var typeCache sync.Map // unmarshalKey → *typedef type typedef struct { decoders map[unmarshalKey]decodeFunc fields []structField + root reflect.Type } func newTypedef(rt reflect.Type) (*typedef, error) { def := &typedef{ decoders: make(map[unmarshalKey]decodeFunc), + // encoders: make(map[encodeKey]encodeFunc), + root: rt, } err := def.init(rt) return def, err } func (def *typedef) init(rt reflect.Type) error { + rt0 := rt for rt.Kind() == reflect.Pointer { rt = rt.Elem() } @@ -36,8 +40,15 @@ func (def *typedef) init(rt reflect.Type) error { return nil } - var err error - def.fields, err = structFields(rt) + // skip visiting struct fields if encoding will be bypassed by a custom marshaler + if shouldBypassEncodeItem(rt0) || shouldBypassEncodeItem(rt) { + return nil + } + + fieldptr, err := def.structFields(rt, true) + if fieldptr != nil { + def.fields = *fieldptr + } return err } @@ -87,7 +98,7 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) { case reflect.Struct: return encodeItem(def.fields, rv) case reflect.Map: - enc, err := encodeMapM(rv.Type(), flagNone) + enc, err := def.encodeMapM(rv.Type(), flagNone) if err != nil { return nil, err } @@ -437,10 +448,31 @@ type structField struct { isZero func(reflect.Value) bool } -func structFields(rt reflect.Type) ([]structField, error) { +// type encodeKey struct { +// rt reflect.Type +// flags encodeFlags +// } + +func (def *typedef) sameAsRoot(rt reflect.Type) bool { + switch { + case rt == def.root: + return true + case def.root.Kind() == reflect.Pointer && rt.Kind() != reflect.Pointer: + return def.root.Elem() == rt + case def.root.Kind() != reflect.Pointer && rt.Kind() == reflect.Pointer: + return rt.Elem() == def.root + } + return false +} + +func (def *typedef) structFields(rt reflect.Type, isRoot bool) (*[]structField, error) { + if !isRoot && def.sameAsRoot(rt) { + return &def.fields, nil + } + var fields []structField err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error { - enc, err := encodeType(vt, flags) + enc, err := def.encodeType(vt, flags) if err != nil { return err } @@ -449,12 +481,12 @@ func structFields(rt reflect.Type) ([]structField, error) { name: name, flags: flags, enc: enc, - isZero: isZeroFunc(vt), + isZero: def.isZeroFunc(vt), } fields = append(fields, field) return nil }) - return fields, err + return &fields, err } var ( diff --git a/encoding_test.go b/encoding_test.go index e1ac632..f266f1e 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -711,6 +711,74 @@ var itemEncodingTests = []struct { "thing": &types.AttributeValueMemberN{Value: "52"}, }, }, + { + name: "self-recursive struct", + in: Person{ + Spouse: &Person{ + Name: "Peggy", + Children: []Person{{Name: "Bobby", Children: []Person{}}}, + }, + Children: []Person{{Name: "Bobby", Children: []Person{}}}, + Name: "Hank", + }, + out: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Hank"}, + "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Peggy"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }, + }, + }}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }}, + }, + }, + { + name: "struct with recursive field", + in: Friend{ + ID: 555, + Person: Person{ + Spouse: &Person{ + Name: "Peggy", + Children: []Person{{Name: "Bobby", Children: []Person{}}}, + }, + Children: []Person{{Name: "Bobby", Children: []Person{}}}, + Name: "Hank", + }, + Nickname: "H-Dawg", + }, + out: map[string]types.AttributeValue{ + "ID": &types.AttributeValueMemberN{Value: "555"}, + "Nickname": &types.AttributeValueMemberS{Value: "H-Dawg"}, + "Person": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Hank"}, + "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Peggy"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }, + }, + }}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Name": &types.AttributeValueMemberS{Value: "Bobby"}, + "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }}, + }}, + }, + }, } type embedded struct { @@ -799,6 +867,18 @@ func (cim *customItemMarshaler) UnmarshalDynamoItem(item Item) error { return nil } +type Person struct { + Spouse *Person + Children []Person + Name string +} + +type Friend struct { + ID int + Person Person + Nickname string +} + func byteSlicePtr(a []byte) *[]byte { return &a } diff --git a/put_test.go b/put_test.go index 3c37edc..cbdd400 100644 --- a/put_test.go +++ b/put_test.go @@ -13,7 +13,7 @@ func TestPut(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() type widget2 struct { @@ -64,7 +64,7 @@ func TestPut(t *testing.T) { t.Errorf("bad old value. %#v ≠ %#v", oldValue, item) } - if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTable { + if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTableWidgets { t.Errorf("bad consumed capacity: %#v", cc) } @@ -79,7 +79,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() type awsWidget struct { diff --git a/query.go b/query.go index b2eaccc..689bffc 100644 --- a/query.go +++ b/query.go @@ -582,8 +582,8 @@ func (q *Query) keys() Item { return keys } -func (q *Query) keysAndAttribs() *types.KeysAndAttributes { - kas := &types.KeysAndAttributes{ +func (q *Query) keysAndAttribs() types.KeysAndAttributes { + kas := types.KeysAndAttributes{ Keys: []Item{q.keys()}, ExpressionAttributeNames: q.nameExpr, ConsistentRead: &q.consistent, diff --git a/query_test.go b/query_test.go index a3d714f..d4ac428 100644 --- a/query_test.go +++ b/query_test.go @@ -15,7 +15,7 @@ func TestGetAllCount(t *testing.T) { t.Skip(offlineSkipMsg) } ctx := context.TODO() - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -149,7 +149,7 @@ func TestQueryPaging(t *testing.T) { t.Skip(offlineSkipMsg) } ctx := context.TODO() - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -201,7 +201,7 @@ func TestQueryMagicLEK(t *testing.T) { t.Skip(offlineSkipMsg) } ctx := context.Background() - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -292,7 +292,7 @@ func TestQueryBadKeys(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.Background() t.Run("hash key", func(t *testing.T) { diff --git a/scan_test.go b/scan_test.go index b4d0ef3..de2dcc9 100644 --- a/scan_test.go +++ b/scan_test.go @@ -11,7 +11,7 @@ func TestScan(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() // first, add an item to make sure there is at least one @@ -106,7 +106,7 @@ func TestScanPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() // prepare data @@ -126,11 +126,9 @@ func TestScanPaging(t *testing.T) { widgets := [10]widget{} itr := table.Scan().Consistent(true).SearchLimit(1).Iter() for i := 0; i < len(widgets); i++ { - more := itr.Next(ctx, &widgets[i]) + itr.Next(ctx, &widgets[i]) if itr.Err() != nil { t.Error("unexpected error", itr.Err()) - } - if !more { break } lek, err := itr.LastEvaluatedKey(context.Background()) @@ -150,26 +148,21 @@ func TestScanPaging(t *testing.T) { const segments = 2 ctx := context.Background() widgets := [10]widget{} - itr := table.Scan().Consistent(true).SearchLimit(1).IterParallel(ctx, segments) - for i := 0; i < len(widgets)/segments; i++ { - var more bool - for j := 0; j < segments; j++ { - more = itr.Next(ctx, &widgets[i*segments+j]) - if !more && j != segments-1 { - t.Error("bad number of results from parallel scan") - } + limit := int(len(widgets) / segments) + itr := table.Scan().Consistent(true).SearchLimit(limit).IterParallel(ctx, segments) + for i := 0; i < len(widgets); { + for ; i < len(widgets) && itr.Next(ctx, &widgets[i]); i++ { } if itr.Err() != nil { t.Error("unexpected error", itr.Err()) - } - if !more { break } - leks, err := itr.LastEvaluatedKeys(context.Background()) + t.Logf("parallel chunk: %d", i) + lek, err := itr.LastEvaluatedKeys(ctx) if err != nil { - t.Error("LEK error", err) + t.Fatal("lek error", err) } - itr = table.Scan().SearchLimit(1).IterParallelStartFrom(ctx, leks) + itr = table.Scan().SearchLimit(limit).IterParallelStartFrom(ctx, lek) } for i, w := range widgets { if w.UserID == 0 && w.Time.IsZero() { @@ -183,7 +176,7 @@ func TestScanMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.Background() widgets := []interface{}{ diff --git a/ttl_test.go b/ttl_test.go index 4076647..ed78111 100644 --- a/ttl_test.go +++ b/ttl_test.go @@ -9,7 +9,7 @@ func TestDescribeTTL(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() desc, err := table.DescribeTTL().Run(ctx) diff --git a/tx_test.go b/tx_test.go index eb09864..d44d709 100644 --- a/tx_test.go +++ b/tx_test.go @@ -25,7 +25,7 @@ func TestTx(t *testing.T) { widget2 := widget{UserID: 69, Time: date2, Msg: "cat"} // basic write & check - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) tx := testDB.WriteTx() var cc, ccold ConsumedCapacity tx.Idempotent(true) @@ -190,7 +190,7 @@ func TestTxRetry(t *testing.T) { date1 := time.Date(1999, 1, 1, 1, 1, 1, 0, time.UTC) widget1 := widget{UserID: 69, Time: date1, Msg: "dog", Count: 0} - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) if err := table.Put(widget1).Run(ctx); err != nil { t.Fatal(err) } diff --git a/update_test.go b/update_test.go index c3a29e1..fce154e 100644 --- a/update_test.go +++ b/update_test.go @@ -14,7 +14,7 @@ func TestUpdate(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() type widget2 struct { @@ -170,7 +170,7 @@ func TestUpdateNil(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() // first, add an item to make sure there is at least one @@ -226,7 +226,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() type widget2 struct { diff --git a/updatetable_test.go b/updatetable_test.go index 5b1e863..1fa5913 100644 --- a/updatetable_test.go +++ b/updatetable_test.go @@ -10,7 +10,7 @@ func _TestUpdateTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) ctx := context.TODO() desc, err := table.UpdateTable().CreateIndex(Index{ @@ -34,8 +34,8 @@ func _TestUpdateTable(t *testing.T) { if err != nil { t.Error(err) } - if desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.Status != UpdatingStatus { t.Error("bad status:", desc.Status, "≠", UpdatingStatus)