From e86c7f05eed323eba916aef2be29ebac98c50702 Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 14:12:18 +0900 Subject: [PATCH] add multi-table batch tests --- README.md | 43 ++++++++-------------- batch_test.go | 84 ++++++++++++++++++++++++++++--------------- db_test.go | 80 +++++++++++++++++++++++++++++------------ delete_test.go | 2 +- describetable_test.go | 6 ++-- put_test.go | 6 ++-- query_test.go | 8 ++--- scan_test.go | 6 ++-- ttl_test.go | 2 +- tx_test.go | 4 +-- update_test.go | 6 ++-- updatetable_test.go | 6 ++-- 12 files changed, 150 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index b4ac10c..8cf6fa5 100644 --- a/README.md +++ b/README.md @@ -232,38 +232,23 @@ err := db.Table("Books").Get("ID", 555).One(dynamo.AWSEncoding(&someBook)) ### Integration tests -By default, tests are run in offline mode. Create a table called `TestDB`, with a number partition key called `UserID` and a string sort key called `Time`. It also needs a Global Secondary Index called `Msg-Time-index` with a string partition key called `Msg` and a string sort key called `Time`. +By default, tests are run in offline mode. In order to run the integration tests, some environment variables need to be set. -Change the table name with the environment variable `DYNAMO_TEST_TABLE`. You must specify `DYNAMO_TEST_REGION`, setting it to the AWS region where your test table is. - - - ```bash -DYNAMO_TEST_REGION=us-west-2 go test github.com/guregu/dynamo/... -cover - ``` - -If you want to use [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html) to run local tests, specify `DYNAMO_TEST_ENDPOINT`. - - ```bash -DYNAMO_TEST_REGION=us-west-2 DYNAMO_TEST_ENDPOINT=http://localhost:8000 go test github.com/guregu/dynamo/... -cover - ``` - -Example of using [aws-cli](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Tools.CLI.html) to create a table for testing. +To run the tests against [DynamoDB Local](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DynamoDBLocal.html): ```bash -aws dynamodb create-table \ - --table-name TestDB \ - --attribute-definitions \ - AttributeName=UserID,AttributeType=N \ - AttributeName=Time,AttributeType=S \ - AttributeName=Msg,AttributeType=S \ - --key-schema \ - AttributeName=UserID,KeyType=HASH \ - AttributeName=Time,KeyType=RANGE \ - --global-secondary-indexes \ - IndexName=Msg-Time-index,KeySchema=[{'AttributeName=Msg,KeyType=HASH'},{'AttributeName=Time,KeyType=RANGE'}],Projection={'ProjectionType=ALL'} \ - --billing-mode PAY_PER_REQUEST \ - --region us-west-2 \ - --endpoint-url http://localhost:8000 # using DynamoDB local +# Use Docker to run DynamoDB local on port 8880 +docker compose -f '.github/docker-compose.yml' up -d + +# Run the tests with a fresh table +# The tables will be created automatically +DYNAMO_TEST_ENDPOINT='http://localhost:8880' \ + DYNAMO_TEST_REGION='local' \ + DYNAMO_TEST_TABLE='TestDB-%' \ # the % will be replaced the current timestamp + AWS_ACCESS_KEY_ID='dummy' \ + AWS_SECRET_ACCESS_KEY='dummy' \ + AWS_REGION='local' \ + go test -v -race ./... -cover -coverpkg=./... ``` ### License diff --git a/batch_test.go b/batch_test.go index 398e455..e18502d 100644 --- a/batch_test.go +++ b/batch_test.go @@ -11,7 +11,10 @@ func TestBatchGetWrite(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table1 := testDB.Table(testTableWidgets) + table2 := testDB.Table(testTableSprockets) + tables := []Table{table1, table2} + totalBatchSize := batchSize * len(tables) items := make([]interface{}, batchSize) widgets := make(map[int]widget) @@ -28,10 +31,19 @@ func TestBatchGetWrite(t *testing.T) { keys[i] = Keys{i, now} } + var batches []*BatchWrite + for _, table := range tables { + b := table.Batch().Write().Put(items...) + batches = append(batches, b) + } + batch1 := batches[0] + for _, b := range batches[1:] { + batch1.Merge(b) + } var wcc ConsumedCapacity - wrote, err := table.Batch().Write().Put(items...).ConsumedCapacity(&wcc).Run() - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err := batch1.ConsumedCapacity(&wcc).Run() + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) @@ -41,20 +53,29 @@ func TestBatchGetWrite(t *testing.T) { } // get all - var results []widget + var gets []*BatchGet + for _, table := range tables { + b := table.Batch("UserID", "Time"). + Get(keys...). + Project("UserID", "Time"). + Consistent(true) + gets = append(gets, b) + } + var cc ConsumedCapacity - err = table.Batch("UserID", "Time"). - Get(keys...). - Project("UserID", "Time"). - Consistent(true). - ConsumedCapacity(&cc). - All(&results) + get1 := gets[0].ConsumedCapacity(&cc) + for _, b := range gets[1:] { + get1.Merge(b) + } + + var results []widget + err = get1.All(&results) if err != nil { t.Error("unexpected error:", err) } - if len(results) != batchSize { - t.Error("expected", batchSize, "results, got", len(results)) + if len(results) != totalBatchSize { + t.Error("expected", totalBatchSize, "results, got", len(results)) } if cc.Total == 0 { @@ -72,26 +93,31 @@ func TestBatchGetWrite(t *testing.T) { } // delete both - wrote, err = table.Batch("UserID", "Time").Write(). - Delete(keys...).Run() - if wrote != batchSize { - t.Error("unexpected wrote:", wrote, "≠", batchSize) + wrote, err = table1.Batch("UserID", "Time").Write(). + Delete(keys...). + DeleteInRange(table2, "UserID", "Time", keys...). + Run() + if wrote != totalBatchSize { + t.Error("unexpected wrote:", wrote, "≠", totalBatchSize) } if err != nil { t.Error("unexpected error:", err) } // get both again - results = nil - err = table.Batch("UserID", "Time"). - Get(keys...). - Consistent(true). - All(&results) - if err != ErrNotFound { - t.Error("expected ErrNotFound, got", err) - } - if len(results) != 0 { - t.Error("expected 0 results, got", len(results)) + { + var results []widget + err = table1.Batch("UserID", "Time"). + Get(keys...). + FromRange(table2, "UserID", "Time", keys...). + Consistent(true). + All(&results) + if err != ErrNotFound { + t.Error("expected ErrNotFound, got", err) + } + if len(results) != 0 { + t.Error("expected 0 results, got", len(results)) + } } } @@ -99,7 +125,7 @@ func TestBatchGetEmptySets(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) now := time.Now().UnixNano() / 1000000000 id := int(now) @@ -150,7 +176,7 @@ func TestBatchGetEmptySets(t *testing.T) { } func TestBatchEmptyInput(t *testing.T) { - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) var out []any err := table.Batch("UserID", "Time").Get().All(&out) if err != ErrNoInput { diff --git a/db_test.go b/db_test.go index 3082e02..e9cd8c2 100644 --- a/db_test.go +++ b/db_test.go @@ -2,6 +2,7 @@ package dynamo import ( "errors" + "fmt" "log" "os" "strconv" @@ -16,8 +17,9 @@ import ( ) var ( - testDB *DB - testTable = "TestDB" + testDB *DB + testTableWidgets = "TestDB" + testTableSprockets = "TestDB-Sprockets" ) var dummyCreds = credentials.NewStaticCredentials("dummy", "dummy", "") @@ -35,37 +37,71 @@ type widget struct { } func TestMain(m *testing.M) { - var endpoint *string - if region := os.Getenv("DYNAMO_TEST_REGION"); region != "" { - if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { - endpoint = aws.String(dte) - } + var endpoint, region *string + if dte := os.Getenv("DYNAMO_TEST_ENDPOINT"); dte != "" { + endpoint = &dte + } + if dtr := os.Getenv("DYNAMO_TEST_REGION"); dtr != "" { + region = &dtr + } + if endpoint != nil && region == nil { + dtr := "local" + region = &dtr + } + if region != nil { testDB = New(session.Must(session.NewSession()), &aws.Config{ - Region: aws.String(region), + Region: region, Endpoint: endpoint, // LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), }) } + + timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10) + var offline bool if table := os.Getenv("DYNAMO_TEST_TABLE"); table != "" { + offline = false // Test-% --> Test-1707708680863 - table = strings.ReplaceAll(table, "%", strconv.FormatInt(time.Now().UnixMilli(), 10)) - testTable = table + table = strings.ReplaceAll(table, "%", timestamp) + testTableWidgets = table + } + if table := os.Getenv("DYNAMO_TEST_TABLE2"); table != "" { + table = strings.ReplaceAll(table, "%", timestamp) + testTableSprockets = table + } else if !offline { + testTableSprockets = testTableWidgets + "-Sprockets" + } + + if !offline && testTableWidgets == testTableSprockets { + panic(fmt.Sprintf("DYNAMO_TEST_TABLE must not equal DYNAMO_TEST_TABLE2. got DYNAMO_TEST_TABLE=%q and DYNAMO_TEST_TABLE2=%q", + testTableWidgets, testTableSprockets)) + } + + var shouldCreate bool + switch os.Getenv("DYNAMO_TEST_CREATE_TABLE") { + case "1", "true", "yes": + shouldCreate = true + case "0", "false", "no": + shouldCreate = false + default: + shouldCreate = endpoint != nil } var created []Table if testDB != nil { - table := testDB.Table(testTable) - log.Println("Checking test table:", testTable) - _, err := table.Describe().Run() - switch { - case isTableNotExistsErr(err) && endpoint != nil: - log.Println("Creating test table:", testTable) - if err := testDB.CreateTable(testTable, widget{}).Run(); err != nil { + for _, name := range []string{testTableWidgets, testTableSprockets} { + table := testDB.Table(name) + log.Println("Checking test table:", name) + _, err := table.Describe().Run() + switch { + case isTableNotExistsErr(err) && shouldCreate: + log.Println("Creating test table:", name) + if err := testDB.CreateTable(name, widget{}).Run(); err != nil { + panic(err) + } + created = append(created, testDB.Table(name)) + case err != nil: panic(err) } - created = append(created, testDB.Table(testTable)) - case err != nil: - panic(err) } } @@ -98,13 +134,13 @@ func TestListTables(t *testing.T) { found := false for _, t := range tables { - if t == testTable { + if t == testTableWidgets { found = true break } } if !found { - t.Error("couldn't find testTable", testTable, "in:", tables) + t.Error("couldn't find testTable", testTableWidgets, "in:", tables) } } diff --git a/delete_test.go b/delete_test.go index 751565d..9d11dd7 100644 --- a/delete_test.go +++ b/delete_test.go @@ -10,7 +10,7 @@ func TestDelete(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to delete later item := widget{ diff --git a/describetable_test.go b/describetable_test.go index 34bc50e..9e22173 100644 --- a/describetable_test.go +++ b/describetable_test.go @@ -8,7 +8,7 @@ func TestDescribeTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.Describe().Run() if err != nil { @@ -16,8 +16,8 @@ func TestDescribeTable(t *testing.T) { return } - if desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.HashKey != "UserID" || desc.RangeKey != "Time" { t.Error("bad keys:", desc.HashKey, desc.RangeKey) diff --git a/put_test.go b/put_test.go index f0fd48a..fe2ad63 100644 --- a/put_test.go +++ b/put_test.go @@ -12,7 +12,7 @@ func TestPut(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget @@ -62,7 +62,7 @@ func TestPut(t *testing.T) { t.Errorf("bad old value. %#v ≠ %#v", oldValue, item) } - if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTable { + if cc.Total < 1 || cc.Table < 1 || cc.TableName != testTableWidgets { t.Errorf("bad consumed capacity: %#v", cc) } @@ -77,7 +77,7 @@ func TestPutAndQueryAWSEncoding(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type awsWidget struct { XUserID int `dynamodbav:"UserID"` diff --git a/query_test.go b/query_test.go index 398a359..f1d7992 100644 --- a/query_test.go +++ b/query_test.go @@ -13,7 +13,7 @@ func TestGetAllCount(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -146,7 +146,7 @@ func TestQueryPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -193,7 +193,7 @@ func TestQueryMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ @@ -276,7 +276,7 @@ func TestQueryBadKeys(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) t.Run("hash key", func(t *testing.T) { var v interface{} diff --git a/scan_test.go b/scan_test.go index df0a056..9941b6a 100644 --- a/scan_test.go +++ b/scan_test.go @@ -13,7 +13,7 @@ func TestScan(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -107,7 +107,7 @@ func TestScanPaging(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // prepare data insert := make([]interface{}, 10) @@ -168,7 +168,7 @@ func TestScanMagicLEK(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) widgets := []interface{}{ widget{ diff --git a/ttl_test.go b/ttl_test.go index 9ded4e3..9ffcafc 100644 --- a/ttl_test.go +++ b/ttl_test.go @@ -8,7 +8,7 @@ func TestDescribeTTL(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.DescribeTTL().Run() if err != nil { diff --git a/tx_test.go b/tx_test.go index 626d026..0fd28e0 100644 --- a/tx_test.go +++ b/tx_test.go @@ -21,7 +21,7 @@ func TestTx(t *testing.T) { widget2 := widget{UserID: 69, Time: date2, Msg: "cat"} // basic write & check - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) tx := testDB.WriteTx() var cc, ccold ConsumedCapacity tx.Idempotent(true) @@ -184,7 +184,7 @@ func TestTxRetry(t *testing.T) { date1 := time.Date(1999, 1, 1, 1, 1, 1, 0, time.UTC) widget1 := widget{UserID: 69, Time: date1, Msg: "dog", Count: 0} - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) if err := table.Put(widget1).Run(); err != nil { t.Fatal(err) } diff --git a/update_test.go b/update_test.go index 56dce4d..ae22824 100644 --- a/update_test.go +++ b/update_test.go @@ -13,7 +13,7 @@ func TestUpdate(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget @@ -168,7 +168,7 @@ func TestUpdateNil(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) // first, add an item to make sure there is at least one item := widget{ @@ -223,7 +223,7 @@ func TestUpdateSetAutoOmit(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) type widget2 struct { widget diff --git a/updatetable_test.go b/updatetable_test.go index 04c2c1b..472b641 100644 --- a/updatetable_test.go +++ b/updatetable_test.go @@ -9,7 +9,7 @@ func _TestUpdateTable(t *testing.T) { if testDB == nil { t.Skip(offlineSkipMsg) } - table := testDB.Table(testTable) + table := testDB.Table(testTableWidgets) desc, err := table.UpdateTable().CreateIndex(Index{ Name: "test123", @@ -32,8 +32,8 @@ func _TestUpdateTable(t *testing.T) { if err != nil { t.Error(err) } - if desc.Name != testTable { - t.Error("wrong name:", desc.Name, "≠", testTable) + if desc.Name != testTableWidgets { + t.Error("wrong name:", desc.Name, "≠", testTableWidgets) } if desc.Status != UpdatingStatus { t.Error("bad status:", desc.Status, "≠", UpdatingStatus)