Skip to content

Commit

Permalink
dynamolock: add context enabled calls
Browse files Browse the repository at this point in the history
Addresses #106
  • Loading branch information
ucirello committed Aug 31, 2020
1 parent 0c8bdc7 commit 829fa9a
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 38 deletions.
110 changes: 81 additions & 29 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ func WithSessionMonitor(safeTime time.Duration, callback func()) AcquireLockOpti

// AcquireLock holds the defined lock.
func (c *Client) AcquireLock(key string, opts ...AcquireLockOption) (*Lock, error) {
return c.AcquireLockWithContext(context.Background(), key, opts...)
}

// AcquireLockWithContext holds the defined lock. The given context is passed
// down to the underlying dynamoDB call.
func (c *Client) AcquireLockWithContext(ctx context.Context, key string, opts ...AcquireLockOption) (*Lock, error) {
if c.isClosed() {
return nil, ErrClientClosed
}
Expand All @@ -260,10 +266,10 @@ func (c *Client) AcquireLock(key string, opts ...AcquireLockOption) (*Lock, erro
for _, opt := range opts {
opt(req)
}
return c.acquireLock(req)
return c.acquireLock(ctx, req)
}

func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) {
func (c *Client) acquireLock(ctx context.Context, opt *acquireLockOptions) (*Lock, error) {
// Hold the read lock when acquiring locks. This prevents us from
// acquiring a lock while the Client is being closed as we hold the
// write lock during close.
Expand Down Expand Up @@ -311,7 +317,10 @@ func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) {
}

for {
l, err := c.storeLock(&getLockOptions)
if err := ctx.Err(); err != nil {
return nil, err
}
l, err := c.storeLock(ctx, &getLockOptions)
if err != nil {
return nil, err
} else if l != nil {
Expand All @@ -322,10 +331,10 @@ func (c *Client) acquireLock(opt *acquireLockOptions) (*Lock, error) {
}
}

func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) {
func (c *Client) storeLock(ctx context.Context, getLockOptions *getLockOptions) (*Lock, error) {
c.logger.Println("Call GetItem to see if the lock for ",
c.partitionKeyName, " =", getLockOptions.partitionKeyName, " exists in the table")
existingLock, err := c.getLockFromDynamoDB(*getLockOptions)
existingLock, err := c.getLockFromDynamoDB(ctx, *getLockOptions)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -369,6 +378,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) {
//if the existing lock does not exist or exists and is released
if existingLock == nil || existingLock.isReleased {
l, err := c.upsertAndMonitorNewOrReleasedLock(
ctx,
getLockOptions.additionalAttributes,
getLockOptions.partitionKeyName,
getLockOptions.deleteLockOnRelease,
Expand Down Expand Up @@ -406,6 +416,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) {
} else if getLockOptions.lockTryingToBeAcquired.recordVersionNumber == existingLock.recordVersionNumber && getLockOptions.lockTryingToBeAcquired.isExpired() {
/* If the version numbers match, then we can acquire the lock, assuming it has already expired */
l, err := c.upsertAndMonitorExpiredLock(
ctx,
getLockOptions.additionalAttributes,
getLockOptions.partitionKeyName,
getLockOptions.deleteLockOnRelease,
Expand Down Expand Up @@ -434,6 +445,7 @@ func (c *Client) storeLock(getLockOptions *getLockOptions) (*Lock, error) {
}

func (c *Client) upsertAndMonitorExpiredLock(
ctx context.Context,
additionalAttributes map[string]*dynamodb.AttributeValue,
key string,
deleteLockOnRelease bool,
Expand All @@ -459,11 +471,12 @@ func (c *Client) upsertAndMonitorExpiredLock(
c.logger.Println("Acquiring an existing lock whose revisionVersionNumber did not change for ",
c.partitionKeyName, " partitionKeyName=", key)
return c.putLockItemAndStartSessionMonitor(
additionalAttributes, key, deleteLockOnRelease, newLockData,
ctx, additionalAttributes, key, deleteLockOnRelease, newLockData,
recordVersionNumber, sessionMonitor, putItemRequest)
}

func (c *Client) upsertAndMonitorNewOrReleasedLock(
ctx context.Context,
additionalAttributes map[string]*dynamodb.AttributeValue,
key string,
deleteLockOnRelease bool,
Expand Down Expand Up @@ -494,12 +507,13 @@ func (c *Client) upsertAndMonitorNewOrReleasedLock(
// expire sooner than it actually will, so they start counting towards
// its expiration before the Put succeeds
c.logger.Println("Acquiring a new lock or an existing yet released lock on ", c.partitionKeyName, "=", key)
return c.putLockItemAndStartSessionMonitor(additionalAttributes, key,
return c.putLockItemAndStartSessionMonitor(ctx, additionalAttributes, key,
deleteLockOnRelease, newLockData,
recordVersionNumber, sessionMonitor, req)
}

func (c *Client) putLockItemAndStartSessionMonitor(
ctx context.Context,
additionalAttributes map[string]*dynamodb.AttributeValue,
key string,
deleteLockOnRelease bool,
Expand All @@ -510,7 +524,7 @@ func (c *Client) putLockItemAndStartSessionMonitor(

lastUpdatedTime := time.Now()

_, err := c.dynamoDB.PutItem(putItemRequest)
_, err := c.dynamoDB.PutItemWithContext(ctx, putItemRequest)
if err != nil {
return nil, parseDynamoDBError(err, "cannot store lock item: lock already acquired by other client")
}
Expand All @@ -533,8 +547,8 @@ func (c *Client) putLockItemAndStartSessionMonitor(
return lockItem, nil
}

func (c *Client) getLockFromDynamoDB(opt getLockOptions) (*Lock, error) {
res, err := c.readFromDynamoDB(opt.partitionKeyName)
func (c *Client) getLockFromDynamoDB(ctx context.Context, opt getLockOptions) (*Lock, error) {
res, err := c.readFromDynamoDB(ctx, opt.partitionKeyName)
if err != nil {
return nil, err
}
Expand All @@ -547,11 +561,11 @@ func (c *Client) getLockFromDynamoDB(opt getLockOptions) (*Lock, error) {
return c.createLockItem(opt, item)
}

func (c *Client) readFromDynamoDB(key string) (*dynamodb.GetItemOutput, error) {
func (c *Client) readFromDynamoDB(ctx context.Context, key string) (*dynamodb.GetItemOutput, error) {
dynamoDBKey := map[string]*dynamodb.AttributeValue{
c.partitionKeyName: {S: aws.String(key)},
}
return c.dynamoDB.GetItem(&dynamodb.GetItemInput{
return c.dynamoDB.GetItemWithContext(ctx, &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(true),
TableName: aws.String(c.tableName),
Key: dynamoDBKey,
Expand Down Expand Up @@ -648,6 +662,15 @@ func (c *Client) heartbeat(ctx context.Context) {
// takes a few minutes for DynamoDB to provision a new instance. Also, if the
// table already exists, it will return an error.
func (c *Client) CreateTable(tableName string, opts ...CreateTableOption) (*dynamodb.CreateTableOutput, error) {
return c.CreateTableWithContext(context.Background(), tableName, opts...)
}

// CreateTableWithContext prepares a DynamoDB table with the right schema for it
// to be used by this locking library. The table should be set up in advance,
// because it takes a few minutes for DynamoDB to provision a new instance.
// Also, if the table already exists, it will return an error. The given context
// is passed down to the underlying dynamoDB call.
func (c *Client) CreateTableWithContext(ctx context.Context, tableName string, opts ...CreateTableOption) (*dynamodb.CreateTableOutput, error) {
if c.isClosed() {
return nil, ErrClientClosed
}
Expand All @@ -659,7 +682,7 @@ func (c *Client) CreateTable(tableName string, opts ...CreateTableOption) (*dyna
for _, opt := range opts {
opt(createTableOptions)
}
return c.createTable(createTableOptions)
return c.createTable(ctx, createTableOptions)
}

// CreateTableOption is an options type for the CreateTable method in the lock
Expand Down Expand Up @@ -692,7 +715,7 @@ func WithProvisionedThroughput(provisionedThroughput *dynamodb.ProvisionedThroug
}
}

func (c *Client) createTable(opt *createDynamoDBTableOptions) (*dynamodb.CreateTableOutput, error) {
func (c *Client) createTable(ctx context.Context, opt *createDynamoDBTableOptions) (*dynamodb.CreateTableOutput, error) {
keySchema := []*dynamodb.KeySchemaElement{
{
AttributeName: aws.String(opt.partitionKeyName),
Expand Down Expand Up @@ -722,18 +745,26 @@ func (c *Client) createTable(opt *createDynamoDBTableOptions) (*dynamodb.CreateT
createTableInput.Tags = opt.tags
}

return c.dynamoDB.CreateTable(createTableInput)
return c.dynamoDB.CreateTableWithContext(ctx, createTableInput)
}

// ReleaseLock releases the given lock if the current user still has it,
// returning true if the lock was successfully released, and false if someone
// else already stole the lock or a problem happened. Deletes the lock item if
// it is released and deleteLockItemOnClose is set.
func (c *Client) ReleaseLock(lockItem *Lock, opts ...ReleaseLockOption) (bool, error) {
return c.ReleaseLockWithContext(context.Background(), lockItem, opts...)
}

// ReleaseLockWithContext releases the given lock if the current user still has it,
// returning true if the lock was successfully released, and false if someone
// else already stole the lock or a problem happened. Deletes the lock item if
// it is released and deleteLockItemOnClose is set.
func (c *Client) ReleaseLockWithContext(ctx context.Context, lockItem *Lock, opts ...ReleaseLockOption) (bool, error) {
if c.isClosed() {
return false, ErrClientClosed
}
err := c.releaseLock(lockItem, opts...)
err := c.releaseLock(ctx, lockItem, opts...)
return err == nil, err
}

Expand Down Expand Up @@ -771,7 +802,7 @@ func ownershipLockCondition(partitionKeyName, recordVersionNumber, ownerName str
return cond
}

func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error {
func (c *Client) releaseLock(ctx context.Context, lockItem *Lock, opts ...ReleaseLockOption) error {
options := &releaseLockOptions{
lockItem: lockItem,
}
Expand Down Expand Up @@ -801,12 +832,12 @@ func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error {
key := c.getItemKeys(lockItem)
ownershipLockCond := ownershipLockCondition(c.partitionKeyName, lockItem.recordVersionNumber, lockItem.ownerName)
if deleteLock {
err := c.deleteLock(ownershipLockCond, key)
err := c.deleteLock(ctx, ownershipLockCond, key)
if err != nil {
return err
}
} else {
err := c.updateLock(data, ownershipLockCond, key)
err := c.updateLock(ctx, data, ownershipLockCond, key)
if err != nil {
return err
}
Expand All @@ -815,7 +846,7 @@ func (c *Client) releaseLock(lockItem *Lock, opts ...ReleaseLockOption) error {
return nil
}

func (c *Client) deleteLock(ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error {
func (c *Client) deleteLock(ctx context.Context, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error {
delExpr, _ := expression.NewBuilder().WithCondition(ownershipLockCond).Build()
deleteItemRequest := &dynamodb.DeleteItemInput{
TableName: aws.String(c.tableName),
Expand All @@ -824,14 +855,14 @@ func (c *Client) deleteLock(ownershipLockCond expression.ConditionBuilder, key m
ExpressionAttributeNames: delExpr.Names(),
ExpressionAttributeValues: delExpr.Values(),
}
_, err := c.dynamoDB.DeleteItem(deleteItemRequest)
_, err := c.dynamoDB.DeleteItemWithContext(ctx, deleteItemRequest)
if err != nil {
return err
}
return nil
}

func (c *Client) updateLock(data []byte, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error {
func (c *Client) updateLock(ctx context.Context, data []byte, ownershipLockCond expression.ConditionBuilder, key map[string]*dynamodb.AttributeValue) error {
update := expression.Set(isReleasedAttr, isReleasedAttrVal)
if len(data) > 0 {
update = update.Set(dataAttr, expression.Value(data))
Expand All @@ -847,14 +878,14 @@ func (c *Client) updateLock(data []byte, ownershipLockCond expression.ConditionB
ExpressionAttributeValues: updateExpr.Values(),
}

_, err := c.dynamoDB.UpdateItem(updateItemRequest)
_, err := c.dynamoDB.UpdateItemWithContext(ctx, updateItemRequest)
return err
}

func (c *Client) releaseAllLocks() error {
func (c *Client) releaseAllLocks(ctx context.Context) error {
var err error
c.locks.Range(func(key interface{}, value interface{}) bool {
err = c.releaseLock(value.(*Lock))
err = c.releaseLock(ctx, value.(*Lock))
return err == nil
})
return err
Expand All @@ -875,22 +906,37 @@ func (c *Client) getItemKeys(lockItem *Lock) map[string]*dynamodb.AttributeValue
// should check lockItem.isExpired() to figure out if it currently has the
// lock.)
func (c *Client) Get(key string) (*Lock, error) {
return c.GetWithContext(context.Background(), key)
}

// GetWithContext finds out who owns the given lock, but does not acquire the
// lock. It returns the metadata currently associated with the given lock. If
// the client currently has the lock, it will return the lock, and operations
// such as releaseLock will work. However, if the client does not have the lock,
// then operations like releaseLock will not work (after calling Get, the caller
// should check lockItem.isExpired() to figure out if it currently has the
// lock.) If the context is canceled, it is going to return the context error
// on local cache hit. The given context is passed down to the underlying
// dynamoDB call.
func (c *Client) GetWithContext(ctx context.Context, key string) (*Lock, error) {
if c.isClosed() {
return nil, ErrClientClosed
}

if err := ctx.Err(); err != nil {
return nil, err
}

getLockOption := getLockOptions{
partitionKeyName: key,
}

keyName := getLockOption.partitionKeyName

v, ok := c.locks.Load(keyName)
if ok {
return v.(*Lock), nil
}

lockItem, err := c.getLockFromDynamoDB(getLockOption)
lockItem, err := c.getLockFromDynamoDB(ctx, getLockOption)
if err != nil {
return nil, err
}
Expand All @@ -916,13 +962,19 @@ func (c *Client) isClosed() bool {

// Close releases all of the locks.
func (c *Client) Close() error {
return c.CloseWithContext(context.Background())
}

// CloseWithContext releases all of the locks. The given context is passed down
// to the underlying dynamoDB calls.
func (c *Client) CloseWithContext(ctx context.Context) error {
err := ErrClientClosed
c.closeOnce.Do(func() {
// Hold the write lock for the duration of the close operation
// to prevent new locks from being acquired.
c.mu.Lock()
defer c.mu.Unlock()
err = c.releaseAllLocks()
err = c.releaseAllLocks(context.Background())
c.stopHeartbeat()
c.closed = true
})
Expand Down
19 changes: 15 additions & 4 deletions client_heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package dynamolock

import (
"context"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -49,12 +50,22 @@ func ReplaceHeartbeatData(data []byte) SendHeartbeatOption {
}
}

// SendHeartbeat indicatee that the given lock is still being worked on. If
// SendHeartbeat indicates that the given lock is still being worked on. If
// using WithHeartbeatPeriod > 0 when setting up this object, then this method
// is unnecessary, because the background thread will be periodically calling it
// and sending heartbeats. However, if WithHeartbeatPeriod = 0, then this method
// must be called to instruct DynamoDB that the lock should not be expired.
func (c *Client) SendHeartbeat(lockItem *Lock, opts ...SendHeartbeatOption) error {
return c.SendHeartbeatWithContext(context.Background(), lockItem, opts...)
}

// SendHeartbeatWithContext indicates that the given lock is still being worked
// on. If using WithHeartbeatPeriod > 0 when setting up this object, then this
// method is unnecessary, because the background thread will be periodically
// calling it and sending heartbeats. However, if WithHeartbeatPeriod = 0, then
// this method must be called to instruct DynamoDB that the lock should not be
// expired. The given context is passed down to the underlying dynamoDB call.
func (c *Client) SendHeartbeatWithContext(ctx context.Context, lockItem *Lock, opts ...SendHeartbeatOption) error {
if c.isClosed() {
return ErrClientClosed
}
Expand All @@ -64,10 +75,10 @@ func (c *Client) SendHeartbeat(lockItem *Lock, opts ...SendHeartbeatOption) erro
for _, opt := range opts {
opt(sho)
}
return c.sendHeartbeat(sho)
return c.sendHeartbeat(ctx, sho)
}

func (c *Client) sendHeartbeat(options *sendHeartbeatOptions) error {
func (c *Client) sendHeartbeat(ctx context.Context, options *sendHeartbeatOptions) error {
leaseDuration := c.leaseDuration

lockItem := options.lockItem
Expand Down Expand Up @@ -109,7 +120,7 @@ func (c *Client) sendHeartbeat(options *sendHeartbeatOptions) error {

lastUpdateOfLock := time.Now()

_, err := c.dynamoDB.UpdateItem(updateItemInput)
_, err := c.dynamoDB.UpdateItemWithContext(ctx, updateItemInput)
if err != nil {
err := parseDynamoDBError(err, "already acquired lock, stopping heartbeats")
if isLockNotGrantedError(err) {
Expand Down
Loading

0 comments on commit 829fa9a

Please sign in to comment.