Skip to content

Commit 1a357e1

Browse files
author
Patrick Robinson
authored
Merge pull request #4 from patrobinson/race-condition-get-lease
Expose race conditions in GetLease
2 parents fb5ee54 + 6efb6de commit 1a357e1

7 files changed

+158
-58
lines changed

.travis.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,4 +11,4 @@ before_install:
1111
- docker pull deangiberson/aws-dynamodb-local
1212
- docker pull dlsniper/kinesalite
1313
install: make get
14-
script: make travis-integration
14+
script: make docker-integration

Makefile

+2-4
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@ integration: get
1010
@go test -timeout 30s -tags=integration
1111

1212
docker-integration:
13-
@docker-compose run --rm gokini make integration
14-
15-
travis-integration:
1613
@docker-compose up -d
1714
@sleep 10
18-
@go test -timeout 30s -tags=integration
15+
@docker-compose run gokini make integration
16+
@docker-compose down

checkpointer.go

+39-43
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,10 @@ import (
77

88
"github.com/aws/aws-sdk-go/aws"
99
"github.com/aws/aws-sdk-go/aws/awserr"
10+
"github.com/aws/aws-sdk-go/aws/client"
1011
"github.com/aws/aws-sdk-go/aws/session"
1112
"github.com/aws/aws-sdk-go/service/dynamodb"
1213
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
13-
"github.com/matryer/try"
1414
log "github.com/sirupsen/logrus"
1515
)
1616

@@ -35,17 +35,19 @@ var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard")
3535

3636
// DynamoCheckpoint implements the Checkpoint interface using DynamoDB as a backend
3737
type DynamoCheckpoint struct {
38-
TableName string
39-
LeaseDuration int
40-
svc dynamodbiface.DynamoDBAPI
41-
Retries int
38+
TableName string
39+
LeaseDuration int
40+
Retries int
41+
svc dynamodbiface.DynamoDBAPI
42+
skipTableCheck bool
4243
}
4344

4445
// Init initialises the DynamoDB Checkpoint
4546
func (checkpointer *DynamoCheckpoint) Init() error {
4647
log.Debug("Creating DynamoDB session")
4748
session, err := session.NewSessionWithOptions(
4849
session.Options{
50+
Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: checkpointer.Retries}},
4951
SharedConfigState: session.SharedConfigEnable,
5052
},
5153
)
@@ -54,7 +56,8 @@ func (checkpointer *DynamoCheckpoint) Init() error {
5456
}
5557

5658
if endpoint := os.Getenv("DYNAMODB_ENDPOINT"); endpoint != "" {
57-
session.Config.Endpoint = aws.String(endpoint)
59+
log.Infof("Using dynamodb endpoint from environment %s", endpoint)
60+
session.Config.Endpoint = &endpoint
5861
}
5962

6063
checkpointer.svc = dynamodb.New(session)
@@ -63,7 +66,7 @@ func (checkpointer *DynamoCheckpoint) Init() error {
6366
checkpointer.LeaseDuration = defaultLeaseDuration
6467
}
6568

66-
if !checkpointer.doesTableExist() {
69+
if !checkpointer.skipTableCheck && !checkpointer.doesTableExist() {
6770
return checkpointer.createTable()
6871
}
6972
return nil
@@ -84,6 +87,14 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
8487
var expressionAttributeValues map[string]*dynamodb.AttributeValue
8588
if !leaseTimeoutOk || !assignedToOk {
8689
conditionalExpression = "attribute_not_exists(AssignedTo)"
90+
if shard.Checkpoint != "" {
91+
conditionalExpression = conditionalExpression + " AND SequenceID = :id"
92+
expressionAttributeValues = map[string]*dynamodb.AttributeValue{
93+
":id": {
94+
S: &shard.Checkpoint,
95+
},
96+
}
97+
}
8798
} else {
8899
assignedTo := *assignedVar.S
89100
leaseTimeout := *leaseVar.S
@@ -108,6 +119,12 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
108119
S: &leaseTimeout,
109120
},
110121
}
122+
if shard.Checkpoint != "" {
123+
conditionalExpression = conditionalExpression + " AND SequenceID = :sid"
124+
expressionAttributeValues[":sid"] = &dynamodb.AttributeValue{
125+
S: &shard.Checkpoint,
126+
}
127+
}
111128
}
112129

113130
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
@@ -122,6 +139,10 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
122139
},
123140
}
124141

142+
if shard.Checkpoint != "" {
143+
marshalledCheckpoint["SequenceID"] = &dynamodb.AttributeValue{S: &shard.Checkpoint}
144+
}
145+
125146
if shard.Checkpoint != "" {
126147
marshalledCheckpoint["Checkpoint"] = &dynamodb.AttributeValue{
127148
S: &shard.Checkpoint,
@@ -229,51 +250,26 @@ func (checkpointer *DynamoCheckpoint) saveItem(item map[string]*dynamodb.Attribu
229250

230251
func (checkpointer *DynamoCheckpoint) conditionalUpdate(conditionExpression string, expressionAttributeValues map[string]*dynamodb.AttributeValue, item map[string]*dynamodb.AttributeValue) error {
231252
return checkpointer.putItem(&dynamodb.PutItemInput{
232-
ConditionExpression: aws.String(conditionExpression),
233-
TableName: aws.String(checkpointer.TableName),
234-
Item: item,
253+
ConditionExpression: aws.String(conditionExpression),
254+
TableName: aws.String(checkpointer.TableName),
255+
Item: item,
235256
ExpressionAttributeValues: expressionAttributeValues,
236257
})
237258
}
238259

239260
func (checkpointer *DynamoCheckpoint) putItem(input *dynamodb.PutItemInput) error {
240-
return try.Do(func(attempt int) (bool, error) {
241-
_, err := checkpointer.svc.PutItem(input)
242-
if awsErr, ok := err.(awserr.Error); ok {
243-
if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException ||
244-
awsErr.Code() == dynamodb.ErrCodeInternalServerError &&
245-
attempt < checkpointer.Retries {
246-
// Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html
247-
time.Sleep(time.Duration(2^attempt*100) * time.Millisecond)
248-
return true, err
249-
}
250-
}
251-
return false, err
252-
})
261+
_, err := checkpointer.svc.PutItem(input)
262+
return err
253263
}
254264

255265
func (checkpointer *DynamoCheckpoint) getItem(shardID string) (map[string]*dynamodb.AttributeValue, error) {
256-
var item *dynamodb.GetItemOutput
257-
err := try.Do(func(attempt int) (bool, error) {
258-
var err error
259-
item, err = checkpointer.svc.GetItem(&dynamodb.GetItemInput{
260-
TableName: aws.String(checkpointer.TableName),
261-
Key: map[string]*dynamodb.AttributeValue{
262-
"ShardID": {
263-
S: aws.String(shardID),
264-
},
266+
item, err := checkpointer.svc.GetItem(&dynamodb.GetItemInput{
267+
TableName: aws.String(checkpointer.TableName),
268+
Key: map[string]*dynamodb.AttributeValue{
269+
"ShardID": {
270+
S: aws.String(shardID),
265271
},
266-
})
267-
if awsErr, ok := err.(awserr.Error); ok {
268-
if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException ||
269-
awsErr.Code() == dynamodb.ErrCodeInternalServerError &&
270-
attempt < checkpointer.Retries {
271-
// Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html
272-
time.Sleep(time.Duration(2^attempt*100) * time.Millisecond)
273-
return true, err
274-
}
275-
}
276-
return false, err
272+
},
277273
})
278274
return item.Item, err
279275
}

checkpointer_integration_test.go

+83
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,83 @@
1+
//+build integration
2+
3+
package gokini
4+
5+
import (
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/aws/aws-sdk-go/aws"
11+
"github.com/aws/aws-sdk-go/service/dynamodb"
12+
)
13+
14+
func TestRaceCondGetLeaseTimeout(t *testing.T) {
15+
checkpoint := &DynamoCheckpoint{
16+
TableName: "TableName",
17+
}
18+
checkpoint.Init()
19+
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
20+
"ShardID": {
21+
S: aws.String("0001"),
22+
},
23+
"AssignedTo": {
24+
S: aws.String("abcd-efgh"),
25+
},
26+
"LeaseTimeout": {
27+
S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
28+
},
29+
"SequenceID": {
30+
S: aws.String("deadbeef"),
31+
},
32+
}
33+
input := &dynamodb.PutItemInput{
34+
TableName: aws.String("TableName"),
35+
Item: marshalledCheckpoint,
36+
}
37+
_, err := checkpoint.svc.PutItem(input)
38+
if err != nil {
39+
t.Fatalf("Error writing to dynamo %s", err)
40+
}
41+
shard := &shardStatus{
42+
ID: "0001",
43+
Checkpoint: "TestRaceCondGetLeaseTimeout",
44+
mux: &sync.Mutex{},
45+
}
46+
err = checkpoint.GetLease(shard, "ijkl-mnop")
47+
48+
if err == nil || err.Error() != ErrLeaseNotAquired {
49+
t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint")
50+
}
51+
}
52+
func TestRaceCondGetLeaseNoAssignee(t *testing.T) {
53+
checkpoint := &DynamoCheckpoint{
54+
TableName: "TableName",
55+
}
56+
checkpoint.Init()
57+
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
58+
"ShardID": {
59+
S: aws.String("0001"),
60+
},
61+
"SequenceID": {
62+
S: aws.String("deadbeef"),
63+
},
64+
}
65+
input := &dynamodb.PutItemInput{
66+
TableName: aws.String("TableName"),
67+
Item: marshalledCheckpoint,
68+
}
69+
_, err := checkpoint.svc.PutItem(input)
70+
if err != nil {
71+
t.Fatalf("Error writing to dynamo %s", err)
72+
}
73+
shard := &shardStatus{
74+
ID: "0001",
75+
Checkpoint: "TestRaceCondGetLeaseNoAssignee",
76+
mux: &sync.Mutex{},
77+
}
78+
err = checkpoint.GetLease(shard, "ijkl-mnop")
79+
80+
if err == nil || err.Error() != ErrLeaseNotAquired {
81+
t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint")
82+
}
83+
}

checkpointer_test.go

+21-8
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,8 @@ func TestDoesTableExist(t *testing.T) {
5555
func TestGetLeaseNotAquired(t *testing.T) {
5656
svc := &mockDynamoDB{tableExist: true}
5757
checkpoint := &DynamoCheckpoint{
58-
TableName: "TableName",
59-
svc: svc,
58+
TableName: "TableName",
59+
skipTableCheck: true,
6060
}
6161
checkpoint.Init()
6262
checkpoint.svc = svc
@@ -82,8 +82,10 @@ func TestGetLeaseNotAquired(t *testing.T) {
8282
func TestGetLeaseAquired(t *testing.T) {
8383
svc := &mockDynamoDB{tableExist: true}
8484
checkpoint := &DynamoCheckpoint{
85-
TableName: "TableName",
85+
TableName: "TableName",
86+
skipTableCheck: true,
8687
}
88+
checkpoint.svc = svc
8789
checkpoint.Init()
8890
checkpoint.svc = svc
8991
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
@@ -96,28 +98,39 @@ func TestGetLeaseAquired(t *testing.T) {
9698
"LeaseTimeout": {
9799
S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
98100
},
101+
"SequenceID": {
102+
S: aws.String("deadbeef"),
103+
},
99104
}
100105
input := &dynamodb.PutItemInput{
101106
TableName: aws.String("TableName"),
102107
Item: marshalledCheckpoint,
103108
}
104109
checkpoint.svc.PutItem(input)
105-
err := checkpoint.GetLease(&shardStatus{
110+
shard := &shardStatus{
106111
ID: "0001",
107-
Checkpoint: "",
112+
Checkpoint: "deadbeef",
108113
mux: &sync.Mutex{},
109-
}, "ijkl-mnop")
114+
}
115+
err := checkpoint.GetLease(shard, "ijkl-mnop")
110116

111117
if err != nil {
112118
t.Errorf("Lease not aquired after timeout %s", err)
113119
}
120+
121+
id, ok := svc.item["SequenceID"]
122+
if !ok {
123+
t.Error("Expected SequenceID to be set by GetLease")
124+
} else if *id.S != "deadbeef" {
125+
t.Errorf("Expected SequenceID to be deadbeef. Got '%s'", *id.S)
126+
}
114127
}
115128

116129
func TestGetLeaseRenewed(t *testing.T) {
117130
svc := &mockDynamoDB{tableExist: true}
118131
checkpoint := &DynamoCheckpoint{
119-
TableName: "TableName",
120-
svc: svc,
132+
TableName: "TableName",
133+
skipTableCheck: true,
121134
}
122135
checkpoint.Init()
123136
checkpoint.svc = svc

consumer.go

+11-1
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/aws/aws-sdk-go/aws"
1212
"github.com/aws/aws-sdk-go/aws/awserr"
13+
"github.com/aws/aws-sdk-go/aws/client"
1314
"github.com/aws/aws-sdk-go/aws/session"
1415
"github.com/aws/aws-sdk-go/service/kinesis"
1516
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
@@ -55,6 +56,7 @@ type KinesisConsumer struct {
5556
EmptyRecordBackoffMs int
5657
LeaseDuration int
5758
Monitoring MonitoringConfiguration
59+
Retries *int
5860
svc kinesisiface.KinesisAPI
5961
checkpointer Checkpointer
6062
stop *chan struct{}
@@ -65,6 +67,8 @@ type KinesisConsumer struct {
6567
mService monitoringService
6668
}
6769

70+
var defaultRetries = 5
71+
6872
// StartConsumer starts the RecordConsumer, calls Init and starts sending records to ProcessRecords
6973
func (kc *KinesisConsumer) StartConsumer() error {
7074
// Set Defaults
@@ -81,9 +85,15 @@ func (kc *KinesisConsumer) StartConsumer() error {
8185
kc.mService = kc.Monitoring.service
8286

8387
if kc.svc == nil && kc.checkpointer == nil {
88+
retries := defaultRetries
89+
if kc.Retries != nil {
90+
retries = *kc.Retries
91+
}
92+
8493
log.Debugf("Creating Kinesis Session")
8594
session, err := session.NewSessionWithOptions(
8695
session.Options{
96+
Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: retries}},
8797
SharedConfigState: session.SharedConfigEnable,
8898
},
8999
)
@@ -96,7 +106,7 @@ func (kc *KinesisConsumer) StartConsumer() error {
96106
kc.svc = kinesis.New(session)
97107
kc.checkpointer = &DynamoCheckpoint{
98108
TableName: kc.TableName,
99-
Retries: 5,
109+
Retries: retries,
100110
LeaseDuration: kc.LeaseDuration,
101111
}
102112
}

docker-compose.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ services:
77
expose:
88
- 4567
99
dynamodb:
10-
image: deangiberson/aws-dynamodb-local
10+
image: amazon/dynamodb-local
1111
ports:
1212
- 8000:8000
1313
expose:

0 commit comments

Comments
 (0)