diff --git a/client.go b/client.go index 8c3dc91..d380289 100644 --- a/client.go +++ b/client.go @@ -251,6 +251,12 @@ func WithSessionMonitor(safeTime time.Duration, callback func()) AcquireLockOpti // AcquireLock holds the defined lock. func (c *Client) AcquireLock(key string, opts ...AcquireLockOption) (*Lock, error) { + return c.AcquireLockWithContext(context.Background(), key, opts...) +} + +// AcquireLockWithContext holds the defined lock. The given context is passed +// down to the underlying dynamoDB call. +func (c *Client) AcquireLockWithContext(ctx context.Context, key string, opts ...AcquireLockOption) (*Lock, error) { if c.isClosed() { return nil, ErrClientClosed } @@ -260,10 +266,10 @@ func (c *Client) AcquireLock(key string, opts ...AcquireLockOption) (*Lock, erro for _, opt := range opts { opt(req) } - return c.acquireLock(req) + return c.acquireLock(ctx, req) } -func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) { +func (c *Client) acquireLock(ctx context.Context, opt *acquireLockOptions) (*Lock, error) { // Hold the read lock when acquiring locks. This prevents us from // acquiring a lock while the Client is being closed as we hold the // write lock during close. @@ -311,7 +317,10 @@ func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) { } for { - l, err := c.storeLock(&getLockOptions) + if err := ctx.Err(); err != nil { + return nil, err + } + l, err := c.storeLock(ctx, &getLockOptions) if err != nil { return nil, err } else if l != nil { @@ -322,10 +331,10 @@ func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) { } } -func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) { +func (c *Client) storeLock(ctx context.Context, getLockOptions *getLockOptions) (*Lock, error) { c.logger.Println("Call GetItem to see if the lock for ", c.partitionKeyName, " =", getLockOptions.partitionKeyName, " exists in the table") - existingLock, err := c.getLockFromDynamoDB(*getLockOptions) + existingLock, err := c.getLockFromDynamoDB(ctx, *getLockOptions) if err != nil { return nil, err } @@ -369,6 +378,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) { //if the existing lock does not exist or exists and is released if existingLock == nil || existingLock.isReleased { l, err := c.upsertAndMonitorNewOrReleasedLock( + ctx, getLockOptions.additionalAttributes, getLockOptions.partitionKeyName, getLockOptions.deleteLockOnRelease, @@ -406,6 +416,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) { } else if getLockOptions.lockTryingToBeAcquired.recordVersionNumber == existingLock.recordVersionNumber && getLockOptions.lockTryingToBeAcquired.isExpired() { /* If the version numbers match, then we can acquire the lock, assuming it has already expired */ l, err := c.upsertAndMonitorExpiredLock( + ctx, getLockOptions.additionalAttributes, getLockOptions.partitionKeyName, getLockOptions.deleteLockOnRelease, @@ -434,6 +445,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) { } func (c *Client) upsertAndMonitorExpiredLock( + ctx context.Context, additionalAttributes map[string]*dynamodb.AttributeValue, key string, deleteLockOnRelease bool, @@ -459,11 +471,12 @@ func (c *Client) upsertAndMonitorExpiredLock( c.logger.Println("Acquiring an existing lock whose revisionVersionNumber did not change for ", c.partitionKeyName, " partitionKeyName=", key) return c.putLockItemAndStartSessionMonitor( - additionalAttributes, key, deleteLockOnRelease, newLockData, + ctx, additionalAttributes, key, deleteLockOnRelease, newLockData, recordVersionNumber, sessionMonitor, putItemRequest) } func (c *Client) upsertAndMonitorNewOrReleasedLock( + ctx context.Context, additionalAttributes map[string]*dynamodb.AttributeValue, key string, deleteLockOnRelease bool, @@ -494,12 +507,13 @@ func (c *Client) upsertAndMonitorNewOrReleasedLock( // expire sooner than it actually will, so they start counting towards // its expiration before the Put succeeds c.logger.Println("Acquiring a new lock or an existing yet released lock on ", c.partitionKeyName, "=", key) - return c.putLockItemAndStartSessionMonitor(additionalAttributes, key, + return c.putLockItemAndStartSessionMonitor(ctx, additionalAttributes, key, deleteLockOnRelease, newLockData, recordVersionNumber, sessionMonitor, req) } func (c *Client) putLockItemAndStartSessionMonitor( + ctx context.Context, additionalAttributes map[string]*dynamodb.AttributeValue, key string, deleteLockOnRelease bool, @@ -510,7 +524,7 @@ func (c *Client) putLockItemAndStartSessionMonitor( lastUpdatedTime := time.Now() - _, err := c.dynamoDB.PutItem(putItemRequest) + _, err := c.dynamoDB.PutItemWithContext(ctx, putItemRequest) if err != nil { return nil, parseDynamoDBError(err, "cannot store lock item: lock already acquired by other client") } @@ -533,8 +547,8 @@ func (c *Client) putLockItemAndStartSessionMonitor( return lockItem, nil } -func (c *Client) getLockFromDynamoDB(opt getLockOptions) (*Lock, error) { - res, err := c.readFromDynamoDB(opt.partitionKeyName) +func (c *Client) getLockFromDynamoDB(ctx context.Context, opt getLockOptions) (*Lock, error) { + res, err := c.readFromDynamoDB(ctx, opt.partitionKeyName) if err != nil { return nil, err } @@ -547,11 +561,11 @@ func (c *Client) getLockFromDynamoDB(opt getLockOptions) (*Lock, error) { return c.createLockItem(opt, item) } -func (c *Client) readFromDynamoDB(key string) (*dynamodb.GetItemOutput, error) { +func (c *Client) readFromDynamoDB(ctx context.Context, key string) (*dynamodb.GetItemOutput, error) { dynamoDBKey := map[string]*dynamodb.AttributeValue{ c.partitionKeyName: {S: aws.String(key)}, } - return c.dynamoDB.GetItem(&dynamodb.GetItemInput{ + return c.dynamoDB.GetItemWithContext(ctx, &dynamodb.GetItemInput{ ConsistentRead: aws.Bool(true), TableName: aws.String(c.tableName), Key: dynamoDBKey, @@ -648,6 +662,15 @@ func (c *Client) heartbeat(ctx context.Context) { // takes a few minutes for DynamoDB to provision a new instance. Also, if the // table already exists, it will return an error. func (c *Client) CreateTable(tableName string, opts ...CreateTableOption) (*dynamodb.CreateTableOutput, error) { + return c.CreateTableWithContext(context.Background(), tableName, opts...) +} + +// CreateTableWithContext prepares a DynamoDB table with the right schema for it +// to be used by this locking library. The table should be set up in advance, +// because it takes a few minutes for DynamoDB to provision a new instance. +// Also, if the table already exists, it will return an error. The given context +// is passed down to the underlying dynamoDB call. +func (c *Client) CreateTableWithContext(ctx context.Context, tableName string, opts ...CreateTableOption) (*dynamodb.CreateTableOutput, error) { if c.isClosed() { return nil, ErrClientClosed } @@ -659,7 +682,7 @@ func (c *Client) CreateTable(tableName string, opts ...CreateTableOption) (*dyna for _, opt := range opts { opt(createTableOptions) } - return c.createTable(createTableOptions) + return c.createTable(ctx, createTableOptions) } // CreateTableOption is an options type for the CreateTable method in the lock @@ -692,7 +715,7 @@ func WithProvisionedThroughput(provisionedThroughput *dynamodb.ProvisionedThroug } } -func (c *Client) createTable(opt *createDynamoDBTableOptions) (*dynamodb.CreateTableOutput, error) { +func (c *Client) createTable(ctx context.Context, opt *createDynamoDBTableOptions) (*dynamodb.CreateTableOutput, error) { keySchema := []*dynamodb.KeySchemaElement{ { AttributeName: aws.String(opt.partitionKeyName), @@ -722,7 +745,7 @@ func (c *Client) createTable(opt *createDynamoDBTableOptions) (*dynamodb.CreateT createTableInput.Tags = opt.tags } - return c.dynamoDB.CreateTable(createTableInput) + return c.dynamoDB.CreateTableWithContext(ctx, createTableInput) } // ReleaseLock releases the given lock if the current user still has it, @@ -730,10 +753,18 @@ func (c *Client) createTable(opt *createDynamoDBTableOptions) (*dynamodb.CreateT // else already stole the lock or a problem happened. Deletes the lock item if // it is released and deleteLockItemOnClose is set. func (c *Client) ReleaseLock(lockItem *Lock, opts ...ReleaseLockOption) (bool, error) { + return c.ReleaseLockWithContext(context.Background(), lockItem, opts...) +} + +// ReleaseLockWithContext releases the given lock if the current user still has it, +// returning true if the lock was successfully released, and false if someone +// else already stole the lock or a problem happened. Deletes the lock item if +// it is released and deleteLockItemOnClose is set. +func (c *Client) ReleaseLockWithContext(ctx context.Context, lockItem *Lock, opts ...ReleaseLockOption) (bool, error) { if c.isClosed() { return false, ErrClientClosed } - err := c.releaseLock(lockItem, opts...) + err := c.releaseLock(ctx, lockItem, opts...) return err == nil, err } @@ -771,7 +802,7 @@ func ownershipLockCondition(partitionKeyName, recordVersionNumber, ownerName str return cond } -func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error { +func (c *Client) releaseLock(ctx context.Context, lockItem *Lock, opts ...ReleaseLockOption) error { options := &releaseLockOptions{ lockItem: lockItem, } @@ -801,12 +832,12 @@ func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error { key := c.getItemKeys(lockItem) ownershipLockCond := ownershipLockCondition(c.partitionKeyName, lockItem.recordVersionNumber, lockItem.ownerName) if deleteLock { - err := c.deleteLock(ownershipLockCond, key) + err := c.deleteLock(ctx, ownershipLockCond, key) if err != nil { return err } } else { - err := c.updateLock(data, ownershipLockCond, key) + err := c.updateLock(ctx, data, ownershipLockCond, key) if err != nil { return err } @@ -815,7 +846,7 @@ func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error { return nil } -func (c *Client) deleteLock(ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error { +func (c *Client) deleteLock(ctx context.Context, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error { delExpr, _ := expression.NewBuilder().WithCondition(ownershipLockCond).Build() deleteItemRequest := &dynamodb.DeleteItemInput{ TableName: aws.String(c.tableName), @@ -824,14 +855,14 @@ func (c *Client) deleteLock(ownershipLockCond expression.ConditionBuilder, key m ExpressionAttributeNames: delExpr.Names(), ExpressionAttributeValues: delExpr.Values(), } - _, err := c.dynamoDB.DeleteItem(deleteItemRequest) + _, err := c.dynamoDB.DeleteItemWithContext(ctx, deleteItemRequest) if err != nil { return err } return nil } -func (c *Client) updateLock(data []byte, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error { +func (c *Client) updateLock(ctx context.Context, data []byte, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error { update := expression.Set(isReleasedAttr, isReleasedAttrVal) if len(data) > 0 { update = update.Set(dataAttr, expression.Value(data)) @@ -847,14 +878,14 @@ func (c *Client) updateLock(data []byte, ownershipLockCond expression.ConditionB ExpressionAttributeValues: updateExpr.Values(), } - _, err := c.dynamoDB.UpdateItem(updateItemRequest) + _, err := c.dynamoDB.UpdateItemWithContext(ctx, updateItemRequest) return err } -func (c *Client) releaseAllLocks() error { +func (c *Client) releaseAllLocks(ctx context.Context) error { var err error c.locks.Range(func(key interface{}, value interface{}) bool { - err = c.releaseLock(value.(*Lock)) + err = c.releaseLock(ctx, value.(*Lock)) return err == nil }) return err @@ -875,22 +906,37 @@ func (c *Client) getItemKeys(lockItem *Lock) map[string]*dynamodb.AttributeValue // should check lockItem.isExpired() to figure out if it currently has the // lock.) func (c *Client) Get(key string) (*Lock, error) { + return c.GetWithContext(context.Background(), key) +} + +// GetWithContext finds out who owns the given lock, but does not acquire the +// lock. It returns the metadata currently associated with the given lock. If +// the client currently has the lock, it will return the lock, and operations +// such as releaseLock will work. However, if the client does not have the lock, +// then operations like releaseLock will not work (after calling Get, the caller +// should check lockItem.isExpired() to figure out if it currently has the +// lock.) If the context is canceled, it is going to return the context error +// on local cache hit. The given context is passed down to the underlying +// dynamoDB call. +func (c *Client) GetWithContext(ctx context.Context, key string) (*Lock, error) { if c.isClosed() { return nil, ErrClientClosed } + if err := ctx.Err(); err != nil { + return nil, err + } + getLockOption := getLockOptions{ partitionKeyName: key, } - keyName := getLockOption.partitionKeyName - v, ok := c.locks.Load(keyName) if ok { return v.(*Lock), nil } - lockItem, err := c.getLockFromDynamoDB(getLockOption) + lockItem, err := c.getLockFromDynamoDB(ctx, getLockOption) if err != nil { return nil, err } @@ -916,13 +962,19 @@ func (c *Client) isClosed() bool { // Close releases all of the locks. func (c *Client) Close() error { + return c.CloseWithContext(context.Background()) +} + +// CloseWithContext releases all of the locks. The given context is passed down +// to the underlying dynamoDB calls. +func (c *Client) CloseWithContext(ctx context.Context) error { err := ErrClientClosed c.closeOnce.Do(func() { // Hold the write lock for the duration of the close operation // to prevent new locks from being acquired. c.mu.Lock() defer c.mu.Unlock() - err = c.releaseAllLocks() + err = c.releaseAllLocks(context.Background()) c.stopHeartbeat() c.closed = true }) diff --git a/client_heartbeat.go b/client_heartbeat.go index 126fdc5..f873786 100644 --- a/client_heartbeat.go +++ b/client_heartbeat.go @@ -17,6 +17,7 @@ limitations under the License. package dynamolock import ( + "context" "time" "github.com/aws/aws-sdk-go/aws" @@ -49,12 +50,22 @@ func ReplaceHeartbeatData(data []byte) SendHeartbeatOption { } } -// SendHeartbeat indicatee that the given lock is still being worked on. If +// SendHeartbeat indicates that the given lock is still being worked on. If // using WithHeartbeatPeriod > 0 when setting up this object, then this method // is unnecessary, because the background thread will be periodically calling it // and sending heartbeats. However, if WithHeartbeatPeriod = 0, then this method // must be called to instruct DynamoDB that the lock should not be expired. func (c *Client) SendHeartbeat(lockItem *Lock, opts ...SendHeartbeatOption) error { + return c.SendHeartbeatWithContext(context.Background(), lockItem, opts...) +} + +// SendHeartbeatWithContext indicates that the given lock is still being worked +// on. If using WithHeartbeatPeriod > 0 when setting up this object, then this +// method is unnecessary, because the background thread will be periodically +// calling it and sending heartbeats. However, if WithHeartbeatPeriod = 0, then +// this method must be called to instruct DynamoDB that the lock should not be +// expired. The given context is passed down to the underlying dynamoDB call. +func (c *Client) SendHeartbeatWithContext(ctx context.Context, lockItem *Lock, opts ...SendHeartbeatOption) error { if c.isClosed() { return ErrClientClosed } @@ -64,10 +75,10 @@ func (c *Client) SendHeartbeat(lockItem *Lock, opts ...SendHeartbeatOption) erro for _, opt := range opts { opt(sho) } - return c.sendHeartbeat(sho) + return c.sendHeartbeat(ctx, sho) } -func (c *Client) sendHeartbeat(options *sendHeartbeatOptions) error { +func (c *Client) sendHeartbeat(ctx context.Context, options *sendHeartbeatOptions) error { leaseDuration := c.leaseDuration lockItem := options.lockItem @@ -109,7 +120,7 @@ func (c *Client) sendHeartbeat(options *sendHeartbeatOptions) error { lastUpdateOfLock := time.Now() - _, err := c.dynamoDB.UpdateItem(updateItemInput) + _, err := c.dynamoDB.UpdateItemWithContext(ctx, updateItemInput) if err != nil { err := parseDynamoDBError(err, "already acquired lock, stopping heartbeats") if isLockNotGrantedError(err) { diff --git a/client_internal_test.go b/client_internal_test.go index a2b2668..83449e7 100644 --- a/client_internal_test.go +++ b/client_internal_test.go @@ -17,6 +17,7 @@ limitations under the License. package dynamolock import ( + "context" "fmt" "strconv" "sync" @@ -24,6 +25,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" ) @@ -31,13 +33,14 @@ import ( type mockDynamoDBClient struct { dynamodbiface.DynamoDBAPI } -func (m *mockDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { + +func (m *mockDynamoDBClient) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, _ ...request.Option) (*dynamodb.PutItemOutput, error) { return &dynamodb.PutItemOutput{}, nil } -func (m *mockDynamoDBClient) GetItem(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { +func (m *mockDynamoDBClient) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, _ ...request.Option) (*dynamodb.GetItemOutput, error) { return &dynamodb.GetItemOutput{}, nil } -func (m *mockDynamoDBClient) UpdateItem(input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) { +func (m *mockDynamoDBClient) UpdateItemWithContext(ctx context.Context, input *dynamodb.UpdateItemInput, _ ...request.Option) (*dynamodb.UpdateItemOutput, error) { return &dynamodb.UpdateItemOutput{}, nil } @@ -45,7 +48,7 @@ func (m *mockDynamoDBClient) UpdateItem(input *dynamodb.UpdateItemInput) (*dynam This test checks for lock leaks during closing, that is, to make sure that no locks are able to be acquired while the client is closing, and to ensure that we don't have any locks in the internal lock map after a client is closed. - */ +*/ func TestCloseRace(t *testing.T) { mockSvc := &mockDynamoDBClient{} diff --git a/client_test.go b/client_test.go index 304ec35..c1f4f58 100644 --- a/client_test.go +++ b/client_test.go @@ -18,6 +18,7 @@ package dynamolock_test import ( "bytes" + "context" "errors" "fmt" "log" @@ -29,6 +30,7 @@ import ( "cirello.io/dynamolock" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" @@ -1015,7 +1017,7 @@ type fakeDynamoDB struct { dynamodbiface.DynamoDBAPI } -func (f *fakeDynamoDB) GetItem(*dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { +func (f *fakeDynamoDB) GetItemWithContext(context.Context, *dynamodb.GetItemInput, ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("service is offline") }