diff --git a/batchget.go b/batchget.go index 0a16d6b..995056b 100644 --- a/batchget.go +++ b/batchget.go @@ -341,14 +341,15 @@ redo: itr.err = itr.bg.batch.table.db.retry(ctx, func() error { var err error itr.output, err = itr.bg.batch.table.db.client.BatchGetItem(ctx, itr.input) + itr.bg.cc.incRequests() return err }) if itr.err != nil { return false } if itr.bg.cc != nil { - for _, cc := range itr.output.ConsumedCapacity { - addConsumedCapacity(itr.bg.cc, &cc) + for i := range itr.output.ConsumedCapacity { + itr.bg.cc.add(&itr.output.ConsumedCapacity[i]) } } diff --git a/batchwrite.go b/batchwrite.go index adf5021..837055a 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -139,14 +139,15 @@ func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { err := bw.batch.table.db.retry(ctx, func() error { var err error res, err = bw.batch.table.db.client.BatchWriteItem(ctx, req) + bw.cc.incRequests() return err }) if err != nil { return wrote, err } if bw.cc != nil { - for _, cc := range res.ConsumedCapacity { - addConsumedCapacity(bw.cc, &cc) + for i := range res.ConsumedCapacity { + bw.cc.add(&res.ConsumedCapacity[i]) } } diff --git a/delete.go b/delete.go index 67a3710..c38e621 100644 --- a/delete.go +++ b/delete.go @@ -108,10 +108,11 @@ func (d *Delete) run(ctx context.Context) (*dynamodb.DeleteItemOutput, error) { err := d.table.db.retry(ctx, func() error { var err error output, err = d.table.db.client.DeleteItem(ctx, input) + d.cc.incRequests() return err }) - if d.cc != nil && output != nil { - addConsumedCapacity(d.cc, output.ConsumedCapacity) + if output != nil { + d.cc.add(output.ConsumedCapacity) } return output, err } diff --git a/put.go b/put.go index 08fc277..8676a01 100644 --- a/put.go +++ b/put.go @@ -81,10 +81,11 @@ func (p *Put) run(ctx context.Context) (output *dynamodb.PutItemOutput, err erro req := p.input() p.table.db.retry(ctx, func() error { output, err = p.table.db.client.PutItem(ctx, req) + p.cc.incRequests() return err }) - if p.cc != nil && output != nil { - addConsumedCapacity(p.cc, output.ConsumedCapacity) + if output != nil { + p.cc.add(output.ConsumedCapacity) } return } diff --git a/query.go b/query.go index 8e49d5c..ed82375 100644 --- a/query.go +++ b/query.go @@ -221,6 +221,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { err := q.table.db.retry(ctx, func() error { var err error res, err = q.table.db.client.GetItem(ctx, req) + q.cc.incRequests() if err != nil { return err } @@ -232,9 +233,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { if err != nil { return err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) return unmarshalItem(res.Item, out) } @@ -246,6 +245,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { err := q.table.db.retry(ctx, func() error { var err error res, err = q.table.db.client.Query(ctx, req) + q.cc.incRequests() if err != nil { return err } @@ -264,9 +264,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { if err != nil { return err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) return unmarshalItem(res.Items[0], out) } @@ -288,6 +286,7 @@ func (q *Query) Count(ctx context.Context) (int, error) { err := q.table.db.retry(ctx, func() error { var err error res, err = q.table.db.client.Query(ctx, input) + q.cc.incRequests() if err != nil { return err } @@ -301,9 +300,7 @@ func (q *Query) Count(ctx context.Context) (int, error) { if err != nil { return 0, err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) q.startKey = res.LastEvaluatedKey if res.LastEvaluatedKey == nil || @@ -392,15 +389,14 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool { itr.err = itr.query.table.db.retry(ctx, func() error { var err error itr.output, err = itr.query.table.db.client.Query(ctx, itr.input) + itr.query.cc.incRequests() return err }) if itr.err != nil { return false } - if itr.query.cc != nil { - addConsumedCapacity(itr.query.cc, itr.output.ConsumedCapacity) - } + itr.query.cc.add(itr.output.ConsumedCapacity) if len(itr.output.LastEvaluatedKey) > len(itr.exLEK) { itr.exLEK = itr.output.LastEvaluatedKey } diff --git a/scan.go b/scan.go index f649a7f..26792cf 100644 --- a/scan.go +++ b/scan.go @@ -254,6 +254,7 @@ func (s *Scan) Count(ctx context.Context) (int, error) { err := s.table.db.retry(ctx, func() error { var err error out, err = s.table.db.client.Scan(ctx, input) + s.cc.incRequests() return err }) if err != nil { @@ -263,10 +264,7 @@ func (s *Scan) Count(ctx context.Context) (int, error) { count += int(out.Count) scanned += out.ScannedCount - - if s.cc != nil { - addConsumedCapacity(s.cc, out.ConsumedCapacity) - } + s.cc.add(out.ConsumedCapacity) if out.LastEvaluatedKey == nil || (s.limit > 0 && count >= s.limit) || @@ -399,15 +397,14 @@ redo: itr.err = itr.scan.table.db.retry(ctx, func() error { var err error itr.output, err = itr.scan.table.db.client.Scan(ctx, itr.input) + itr.scan.cc.incRequests() return err }) if itr.err != nil { return false } - if itr.scan.cc != nil { - addConsumedCapacity(itr.scan.cc, itr.output.ConsumedCapacity) - } + itr.scan.cc.add(itr.output.ConsumedCapacity) if len(itr.output.LastEvaluatedKey) > len(itr.exLEK) { itr.exLEK = itr.output.LastEvaluatedKey } diff --git a/table.go b/table.go index 056e991..e1135fe 100644 --- a/table.go +++ b/table.go @@ -207,6 +207,7 @@ type ConsumedCapacity struct { // Write is the total number of write capacity units consumed during this operation. // This seems to be only set for transactions. Write float64 + // GSI is a map of Global Secondary Index names to total consumed capacity units. GSI map[string]float64 // GSIRead is a map of Global Secondary Index names to consumed read capacity units. @@ -215,6 +216,7 @@ type ConsumedCapacity struct { // GSIWrite is a map of Global Secondary Index names to consumed write capacity units. // This seems to be only set for transactions. GSIWrite map[string]float64 + // LSI is a map of Local Secondary Index names to total consumed capacity units. LSI map[string]float64 // LSIRead is a map of Local Secondary Index names to consumed read capacity units. @@ -223,6 +225,7 @@ type ConsumedCapacity struct { // LSIWrite is a map of Local Secondary Index names to consumed write capacity units. // This seems to be only set for transactions. LSIWrite map[string]float64 + // Table is the amount of total throughput consumed by the table. Table float64 // TableRead is the amount of read throughput consumed by the table. @@ -233,9 +236,12 @@ type ConsumedCapacity struct { TableWrite float64 // TableName is the name of the table affected by this operation. TableName string + + // Requests is the number of SDK requests made against DynamoDB's API. + Requests int } -func addConsumedCapacity(cc *ConsumedCapacity, raw *types.ConsumedCapacity) { +func (cc *ConsumedCapacity) add(raw *types.ConsumedCapacity) { if cc == nil || raw == nil { return } @@ -302,6 +308,13 @@ func addConsumedCapacity(cc *ConsumedCapacity, raw *types.ConsumedCapacity) { } } +func (cc *ConsumedCapacity) incRequests() { + if cc == nil { + return + } + cc.Requests++ +} + func mergeConsumedCapacity(dst, src *ConsumedCapacity) { if dst == nil || src == nil { return @@ -363,4 +376,5 @@ func mergeConsumedCapacity(dst, src *ConsumedCapacity) { if dst.TableName == "" && src.TableName != "" { dst.TableName = src.TableName } + dst.Requests += src.Requests } diff --git a/table_test.go b/table_test.go index ca356cc..82ff173 100644 --- a/table_test.go +++ b/table_test.go @@ -170,11 +170,20 @@ func TestAddConsumedCapacity(t *testing.T) { } var cc = new(ConsumedCapacity) - addConsumedCapacity(cc, raw) + cc.add(raw) if !reflect.DeepEqual(cc, expected) { t.Error("bad ConsumedCapacity:", cc, "≠", expected) } + + t.Run("request count", func(t *testing.T) { + const expectedReqs = 2 + cc.incRequests() + cc.incRequests() + if cc.Requests != expectedReqs { + t.Error("bad Requests count:", cc.Requests, "≠", expectedReqs) + } + }) } func normalizeDesc(desc *Description) { diff --git a/tx.go b/tx.go index 1caecd5..7679f31 100644 --- a/tx.go +++ b/tx.go @@ -70,9 +70,10 @@ func (tx *GetTx) Run(ctx context.Context) error { err = tx.db.retry(ctx, func() error { var err error resp, err = tx.db.client.TransactGetItems(ctx, input) + tx.cc.incRequests() if tx.cc != nil && resp != nil { - for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + for i := range resp.ConsumedCapacity { + tx.cc.add(&resp.ConsumedCapacity[i]) } } return err @@ -110,9 +111,10 @@ func (tx *GetTx) All(ctx context.Context, out interface{}) error { err = tx.db.retry(ctx, func() error { var err error resp, err = tx.db.client.TransactGetItems(ctx, input) + tx.cc.incRequests() if tx.cc != nil && resp != nil { - for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + for i := range resp.ConsumedCapacity { + tx.cc.add(&resp.ConsumedCapacity[i]) } } return err @@ -256,9 +258,10 @@ func (tx *WriteTx) Run(ctx context.Context) error { } err = tx.db.retry(ctx, func() error { out, err := tx.db.client.TransactWriteItems(ctx, input) - if tx.cc != nil && out != nil { - for _, cc := range out.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + tx.cc.incRequests() + if out != nil { + for i := range out.ConsumedCapacity { + tx.cc.add(&out.ConsumedCapacity[i]) } } return err diff --git a/update.go b/update.go index d8e6d0d..1942efb 100644 --- a/update.go +++ b/update.go @@ -347,10 +347,11 @@ func (u *Update) run(ctx context.Context) (*dynamodb.UpdateItemOutput, error) { err := u.table.db.retry(ctx, func() error { var err error output, err = u.table.db.client.UpdateItem(ctx, input) + u.cc.incRequests() return err }) - if u.cc != nil && output != nil { - addConsumedCapacity(u.cc, output.ConsumedCapacity) + if output != nil { + u.cc.add(output.ConsumedCapacity) } return output, err }