diff --git a/batchget.go b/batchget.go index 58251dd..995056b 100644 --- a/batchget.go +++ b/batchget.go @@ -348,8 +348,8 @@ redo: return false } if itr.bg.cc != nil { - for _, cc := range itr.output.ConsumedCapacity { - addConsumedCapacity(itr.bg.cc, &cc) + for i := range itr.output.ConsumedCapacity { + itr.bg.cc.add(&itr.output.ConsumedCapacity[i]) } } diff --git a/batchwrite.go b/batchwrite.go index b2d0704..837055a 100644 --- a/batchwrite.go +++ b/batchwrite.go @@ -146,8 +146,8 @@ func (bw *BatchWrite) Run(ctx context.Context) (wrote int, err error) { return wrote, err } if bw.cc != nil { - for _, cc := range res.ConsumedCapacity { - addConsumedCapacity(bw.cc, &cc) + for i := range res.ConsumedCapacity { + bw.cc.add(&res.ConsumedCapacity[i]) } } diff --git a/delete.go b/delete.go index d29f91d..c38e621 100644 --- a/delete.go +++ b/delete.go @@ -111,8 +111,8 @@ func (d *Delete) run(ctx context.Context) (*dynamodb.DeleteItemOutput, error) { d.cc.incRequests() return err }) - if d.cc != nil && output != nil { - addConsumedCapacity(d.cc, output.ConsumedCapacity) + if output != nil { + d.cc.add(output.ConsumedCapacity) } return output, err } diff --git a/put.go b/put.go index 1c26a95..8676a01 100644 --- a/put.go +++ b/put.go @@ -84,8 +84,8 @@ func (p *Put) run(ctx context.Context) (output *dynamodb.PutItemOutput, err erro p.cc.incRequests() return err }) - if p.cc != nil && output != nil { - addConsumedCapacity(p.cc, output.ConsumedCapacity) + if output != nil { + p.cc.add(output.ConsumedCapacity) } return } diff --git a/query.go b/query.go index 61c28fa..ed82375 100644 --- a/query.go +++ b/query.go @@ -233,9 +233,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { if err != nil { return err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) return unmarshalItem(res.Item, out) } @@ -266,9 +264,7 @@ func (q *Query) One(ctx context.Context, out interface{}) error { if err != nil { return err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) return unmarshalItem(res.Items[0], out) } @@ -304,9 +300,7 @@ func (q *Query) Count(ctx context.Context) (int, error) { if err != nil { return 0, err } - if q.cc != nil { - addConsumedCapacity(q.cc, res.ConsumedCapacity) - } + q.cc.add(res.ConsumedCapacity) q.startKey = res.LastEvaluatedKey if res.LastEvaluatedKey == nil || @@ -402,9 +396,7 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool { if itr.err != nil { return false } - if itr.query.cc != nil { - addConsumedCapacity(itr.query.cc, itr.output.ConsumedCapacity) - } + itr.query.cc.add(itr.output.ConsumedCapacity) if len(itr.output.LastEvaluatedKey) > len(itr.exLEK) { itr.exLEK = itr.output.LastEvaluatedKey } diff --git a/scan.go b/scan.go index 1069713..26792cf 100644 --- a/scan.go +++ b/scan.go @@ -264,10 +264,7 @@ func (s *Scan) Count(ctx context.Context) (int, error) { count += int(out.Count) scanned += out.ScannedCount - - if s.cc != nil { - addConsumedCapacity(s.cc, out.ConsumedCapacity) - } + s.cc.add(out.ConsumedCapacity) if out.LastEvaluatedKey == nil || (s.limit > 0 && count >= s.limit) || @@ -407,9 +404,7 @@ redo: if itr.err != nil { return false } - if itr.scan.cc != nil { - addConsumedCapacity(itr.scan.cc, itr.output.ConsumedCapacity) - } + itr.scan.cc.add(itr.output.ConsumedCapacity) if len(itr.output.LastEvaluatedKey) > len(itr.exLEK) { itr.exLEK = itr.output.LastEvaluatedKey } diff --git a/table.go b/table.go index 4c54682..e1135fe 100644 --- a/table.go +++ b/table.go @@ -241,7 +241,7 @@ type ConsumedCapacity struct { Requests int } -func addConsumedCapacity(cc *ConsumedCapacity, raw *types.ConsumedCapacity) { +func (cc *ConsumedCapacity) add(raw *types.ConsumedCapacity) { if cc == nil || raw == nil { return } diff --git a/table_test.go b/table_test.go index 99a2f3b..82ff173 100644 --- a/table_test.go +++ b/table_test.go @@ -175,6 +175,15 @@ func TestAddConsumedCapacity(t *testing.T) { if !reflect.DeepEqual(cc, expected) { t.Error("bad ConsumedCapacity:", cc, "≠", expected) } + + t.Run("request count", func(t *testing.T) { + const expectedReqs = 2 + cc.incRequests() + cc.incRequests() + if cc.Requests != expectedReqs { + t.Error("bad Requests count:", cc.Requests, "≠", expectedReqs) + } + }) } func normalizeDesc(desc *Description) { diff --git a/tx.go b/tx.go index 1df1597..7679f31 100644 --- a/tx.go +++ b/tx.go @@ -72,8 +72,8 @@ func (tx *GetTx) Run(ctx context.Context) error { resp, err = tx.db.client.TransactGetItems(ctx, input) tx.cc.incRequests() if tx.cc != nil && resp != nil { - for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + for i := range resp.ConsumedCapacity { + tx.cc.add(&resp.ConsumedCapacity[i]) } } return err @@ -113,8 +113,8 @@ func (tx *GetTx) All(ctx context.Context, out interface{}) error { resp, err = tx.db.client.TransactGetItems(ctx, input) tx.cc.incRequests() if tx.cc != nil && resp != nil { - for _, cc := range resp.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + for i := range resp.ConsumedCapacity { + tx.cc.add(&resp.ConsumedCapacity[i]) } } return err @@ -259,9 +259,9 @@ func (tx *WriteTx) Run(ctx context.Context) error { err = tx.db.retry(ctx, func() error { out, err := tx.db.client.TransactWriteItems(ctx, input) tx.cc.incRequests() - if tx.cc != nil && out != nil { - for _, cc := range out.ConsumedCapacity { - addConsumedCapacity(tx.cc, &cc) + if out != nil { + for i := range out.ConsumedCapacity { + tx.cc.add(&out.ConsumedCapacity[i]) } } return err diff --git a/update.go b/update.go index 286171d..1942efb 100644 --- a/update.go +++ b/update.go @@ -350,8 +350,8 @@ func (u *Update) run(ctx context.Context) (*dynamodb.UpdateItemOutput, error) { u.cc.incRequests() return err }) - if u.cc != nil && output != nil { - addConsumedCapacity(u.cc, output.ConsumedCapacity) + if output != nil { + u.cc.add(output.ConsumedCapacity) } return output, err }