diff --git a/iterator/cdc.go b/iterator/cdc.go index 10e4033..d059328 100644 --- a/iterator/cdc.go +++ b/iterator/cdc.go @@ -31,33 +31,31 @@ import ( // CDCIterator iterates through the table's stream. type CDCIterator struct { - tableName string - partitionKey string - sortKey string - streamsClient *dynamodbstreams.Client - lastSequenceNumber *string - cache chan stypes.Record - streamArn string - shardIterator *string - shardIndex int - tomb *tomb.Tomb - ticker *time.Ticker - p position.Position + tableName string + partitionKey string + sortKey string + streamsClient *dynamodbstreams.Client + cache chan stypes.Record + streamArn string + shardIterator *string + shardIndex int + tomb *tomb.Tomb + ticker *time.Ticker + p position.Position } // NewCDCIterator initializes a CDCIterator starting from the provided position. func NewCDCIterator(ctx context.Context, tableName string, pKey string, sKey string, pollingPeriod time.Duration, client *dynamodbstreams.Client, streamArn string, p position.Position) (*CDCIterator, error) { c := &CDCIterator{ - tableName: tableName, - partitionKey: pKey, - sortKey: sKey, - streamsClient: client, - lastSequenceNumber: nil, - streamArn: streamArn, - tomb: &tomb.Tomb{}, - cache: make(chan stypes.Record, 10), // todo size? - ticker: time.NewTicker(pollingPeriod), - p: p, // todo position handling when pipeline is restarted + tableName: tableName, + partitionKey: pKey, + sortKey: sKey, + streamsClient: client, + streamArn: streamArn, + tomb: &tomb.Tomb{}, + cache: make(chan stypes.Record, 10), // todo size? 
+ ticker: time.NewTicker(pollingPeriod), + p: p, } shardIterator, err := c.getShardIterator(ctx) if err != nil { @@ -107,6 +105,7 @@ func (c *CDCIterator) startCDC() error { case <-c.tomb.Dying(): return fmt.Errorf("tomb is dying: %w", c.tomb.Err()) case <-c.ticker.C: // detect changes every polling period + // todo loop for more than 1000 records out, err := c.streamsClient.GetRecords(c.tomb.Context(nil), &dynamodbstreams.GetRecordsInput{ //nolint:staticcheck // SA1012 tomb expects nil ShardIterator: c.shardIterator, }) @@ -128,7 +127,7 @@ func (c *CDCIterator) startCDC() error { c.shardIterator = out.NextShardIterator if c.shardIterator == nil { - c.shardIterator, err = c.getShardIterator(c.tomb.Context(nil)) //nolint:staticcheck // SA1012 tomb expects nil + c.shardIterator, err = c.moveToNextShard(c.tomb.Context(nil)) //nolint:staticcheck // SA1012 tomb expects nil if err != nil { return fmt.Errorf("failed to get shard iterator: %w", err) } @@ -137,9 +136,12 @@ func (c *CDCIterator) startCDC() error { } } -// getShardIterator gets the shard iterator for the stream. +// getShardIterator gets the shard iterator for the stream depending on the position. +// if the position has no sequence number, it starts from the last shard with the "LATEST" iterator type. +// if the position has a sequence number, it starts from the shard containing that sequence number, with the "AFTER_SEQUENCE_NUMBER" iterator type. +// if the sequence number is the ending sequence number of a shard, it starts from the shard that follows it with the "TRIM_HORIZON" iterator type; if it was not found in any shard (expired), it starts from the first shard with the "TRIM_HORIZON" iterator type. func (c *CDCIterator) getShardIterator(ctx context.Context) (*string, error) { - // Describe the stream to get the shard ID + // describe the stream to get the shard ID. 
+ describeStreamOutput, err := c.streamsClient.DescribeStream(ctx, &dynamodbstreams.DescribeStreamInput{ StreamArn: aws.String(c.streamArn), }) @@ -147,28 +149,91 @@ func (c *CDCIterator) getShardIterator(ctx context.Context) (*string, error) { return nil, fmt.Errorf("failed to describe stream: %w", err) } - // Start from the last shard - if c.shardIndex == 0 { - c.shardIndex = len(describeStreamOutput.StreamDescription.Shards) - 1 + shards := describeStreamOutput.StreamDescription.Shards + if len(shards) == 0 { + return nil, errors.New("no shards found in the stream") + } + + var selectedShardID string + shardIteratorType := stypes.ShardIteratorTypeLatest + // If position has a sequence number, find the shard containing it + if c.p.SequenceNumber != "" { + for i, shard := range shards { + // if the sequence number is at the end of the shard, select the shard that follows it and read it with the "TRIM_HORIZON" iterator type. + if shard.SequenceNumberRange.EndingSequenceNumber != nil && c.p.SequenceNumber == *shard.SequenceNumberRange.EndingSequenceNumber { + selectedShardID = *shards[i+1].ShardId // Start from shard following this one. + shardIteratorType = stypes.ShardIteratorTypeTrimHorizon + break + } + // find the shard that contains the position's sequence number. + if *shard.SequenceNumberRange.StartingSequenceNumber <= c.p.SequenceNumber && + (shard.SequenceNumberRange.EndingSequenceNumber == nil || + c.p.SequenceNumber < *shard.SequenceNumberRange.EndingSequenceNumber) { + selectedShardID = *shard.ShardId + shardIteratorType = stypes.ShardIteratorTypeAfterSequenceNumber + break + } + } + + // if no shard was found containing the sequence number, then start from the beginning of the shards. 
+ if selectedShardID == "" { + selectedShardID = *shards[0].ShardId // Start from the first shard + shardIteratorType = stypes.ShardIteratorTypeTrimHorizon + sdk.Logger(ctx).Warn().Msg("The given sequence number is expired, will start getting events from the beginning of the stream.") + } } else { - c.shardIndex++ + // no sequence number, select the latest shard (the last one in the list) + selectedShardID = *shards[len(shards)-1].ShardId + shardIteratorType = stypes.ShardIteratorTypeLatest } - shardID := describeStreamOutput.StreamDescription.Shards[c.shardIndex].ShardId + // now that we have the shard ID, we can fetch the shard iterator + return c.getShardIteratorForShard(ctx, selectedShardID, shardIteratorType) +} + +// getShardIteratorForShard gets the shard iterator for a specific shard. +func (c *CDCIterator) getShardIteratorForShard(ctx context.Context, shardID string, shardIteratorType stypes.ShardIteratorType) (*string, error) { input := &dynamodbstreams.GetShardIteratorInput{ StreamArn: aws.String(c.streamArn), - ShardId: shardID, - ShardIteratorType: stypes.ShardIteratorTypeLatest, + ShardId: aws.String(shardID), + ShardIteratorType: shardIteratorType, + } + // continue from a specific position. + if shardIteratorType == stypes.ShardIteratorTypeAfterSequenceNumber { + input.SequenceNumber = aws.String(c.p.SequenceNumber) } - // Get the shard iterator + + // get the shard iterator. getShardIteratorOutput, err := c.streamsClient.GetShardIterator(ctx, input) if err != nil { - return nil, fmt.Errorf("failed to get shard iterator: %w", err) + return nil, fmt.Errorf("failed to get shard iterator for shard %s: %w", shardID, err) } return getShardIteratorOutput.ShardIterator, nil } +// moveToNextShard gets the iterator for the shard that follows the current one, once the current shard has been closed. +func (c *CDCIterator) moveToNextShard(ctx context.Context) (*string, error) { + // describe the stream to get the shard details. 
+ describeStreamOutput, err := c.streamsClient.DescribeStream(ctx, &dynamodbstreams.DescribeStreamInput{ + StreamArn: aws.String(c.streamArn), + }) + if err != nil { + return nil, fmt.Errorf("failed to describe stream: %w", err) + } + + shards := describeStreamOutput.StreamDescription.Shards + + // move to the next shard if the current one is closed + if c.shardIndex+1 < len(shards) { + c.shardIndex++ + nextShard := shards[c.shardIndex] + // get shard iterator for the new shard + return c.getShardIteratorForShard(ctx, *nextShard.ShardId, stypes.ShardIteratorTypeTrimHorizon) + } + return nil, errors.New("no more shards available") +} + func (c *CDCIterator) getRecMap(item map[string]stypes.AttributeValue) map[string]interface{} { //nolint:dupl // different types stringMap := make(map[string]interface{}) for k, v := range item { @@ -213,22 +278,22 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) { if rec.EventName == stypes.OperationTypeRemove { image = oldImage } - // prepare key and position structuredKey := opencdc.StructuredData{ c.partitionKey: image[c.partitionKey], } - c.p.Key = *rec.Dynamodb.SequenceNumber - c.p.IteratorType = position.TypeCDC if c.sortKey != "" { - c.p.Key = c.p.Key + "." 
+ fmt.Sprintf("%v", image[c.sortKey]) structuredKey = opencdc.StructuredData{ c.partitionKey: image[c.partitionKey], c.sortKey: image[c.sortKey], } } - c.p.Time = time.Now() - pos, err := c.p.ToRecordPosition() + pos := position.Position{ + IteratorType: position.TypeCDC, + SequenceNumber: *rec.Dynamodb.SequenceNumber, + Time: time.Now(), + } + cdcPos, err := pos.ToRecordPosition() if err != nil { return opencdc.Record{}, fmt.Errorf("failed to build record's CDC position: %w", err) } @@ -237,7 +302,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) { switch rec.EventName { case stypes.OperationTypeInsert: return sdk.Util.Source.NewRecordCreate( - pos, + cdcPos, map[string]string{ opencdc.MetadataCollection: c.tableName, }, @@ -246,7 +311,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) { ), nil case stypes.OperationTypeModify: return sdk.Util.Source.NewRecordUpdate( - pos, + cdcPos, map[string]string{ opencdc.MetadataCollection: c.tableName, }, @@ -256,7 +321,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) { ), nil case stypes.OperationTypeRemove: return sdk.Util.Source.NewRecordDelete( - pos, + cdcPos, map[string]string{ opencdc.MetadataCollection: c.tableName, }, diff --git a/iterator/combined_iterator.go b/iterator/combined_iterator.go index 4f7b970..5c5fe7c 100644 --- a/iterator/combined_iterator.go +++ b/iterator/combined_iterator.go @@ -62,7 +62,7 @@ func NewCombinedIterator( switch p.IteratorType { case position.TypeSnapshot: - if len(p.Key) != 0 { + if len(p.PartitionKey) != 0 { sdk.Logger(ctx). Warn(). Msg("previous snapshot did not complete successfully. 
snapshot will be restarted for consistency.") @@ -72,6 +72,11 @@ func NewCombinedIterator( if err != nil { return nil, fmt.Errorf("could not create the snapshot iterator: %w", err) } + // start listening for changes while snapshot is running + c.cdcIterator, err = NewCDCIterator(ctx, tableName, pKey, sKey, pollingPeriod, streamsClient, streamArn, position.Position{}) + if err != nil { + return nil, fmt.Errorf("could not create the CDC iterator: %w", err) + } case position.TypeCDC: c.cdcIterator, err = NewCDCIterator(ctx, tableName, pKey, sKey, pollingPeriod, streamsClient, streamArn, p) if err != nil { @@ -86,15 +91,7 @@ func NewCombinedIterator( func (c *CombinedIterator) HasNext(ctx context.Context) bool { switch { case c.snapshotIterator != nil: - // if snapshot is over - if !c.snapshotIterator.HasNext(ctx) { - err := c.switchToCDCIterator(ctx) - if err != nil { - return false - } - return false - } - return true + return c.snapshotIterator.HasNext(ctx) case c.cdcIterator != nil: return c.cdcIterator.HasNext(ctx) default: @@ -110,12 +107,8 @@ func (c *CombinedIterator) Next(ctx context.Context) (opencdc.Record, error) { return opencdc.Record{}, err } if !c.snapshotIterator.HasNext(ctx) { - // switch to cdc iterator - err := c.switchToCDCIterator(ctx) - if err != nil { - return opencdc.Record{}, err - } - // change the last record's position to CDC + c.snapshotIterator = nil + // change the last snapshot record's position to CDC r.Position, err = position.ConvertToCDCPosition(r.Position) if err != nil { return opencdc.Record{}, fmt.Errorf("error converting position to CDC: %w", err) @@ -135,16 +128,3 @@ func (c *CombinedIterator) Stop() { c.cdcIterator.Stop() } } - -func (c *CombinedIterator) switchToCDCIterator(ctx context.Context) error { - var err error - pos := position.Position{ - IteratorType: position.TypeCDC, - } - c.cdcIterator, err = NewCDCIterator(ctx, c.tableName, c.partitionKey, c.sortKey, c.pollingPeriod, c.streamsClient, c.streamArn, pos) - if err 
!= nil { - return fmt.Errorf("could not create cdc iterator: %w", err) - } - c.snapshotIterator = nil - return nil -} diff --git a/iterator/snapshot.go b/iterator/snapshot.go index f5898ba..445816f 100644 --- a/iterator/snapshot.go +++ b/iterator/snapshot.go @@ -53,7 +53,6 @@ func NewSnapshotIterator(tableName string, pKey string, sKey string, client *dyn }, nil } -// todo can use dynamodb ScanPaginator instead // refreshPage fetches the next page of items from DynamoDB. func (s *SnapshotIterator) refreshPage(ctx context.Context) error { s.items = nil @@ -87,7 +86,11 @@ func (s *SnapshotIterator) HasNext(ctx context.Context) bool { if s.lastEvaluatedKey != nil || s.firstIt { s.firstIt = false err := s.refreshPage(ctx) - return err == nil + if err != nil { + sdk.Logger(ctx).Error().Err(err).Msg("failed to get the next page of the snapshot.") + return false + } + return true } return false } @@ -106,29 +109,29 @@ func (s *SnapshotIterator) Stop() { } func (s *SnapshotIterator) buildOpenCDCRecord(item map[string]interface{}) (opencdc.Record, error) { - var structuredKey opencdc.StructuredData - s.p.Key = fmt.Sprintf("%v", item[s.partitionKey]) + structuredKey := opencdc.StructuredData{ + s.partitionKey: item[s.partitionKey], + } if s.sortKey != "" { - s.p.Key = s.p.Key + "." 
+ fmt.Sprintf("%v", item[s.sortKey]) structuredKey = opencdc.StructuredData{ s.partitionKey: item[s.partitionKey], s.sortKey: item[s.sortKey], } - } else { - structuredKey = opencdc.StructuredData{ - s.partitionKey: item[s.partitionKey], - } } - s.p.IteratorType = position.TypeSnapshot - s.p.Time = time.Now() - pos, err := s.p.ToRecordPosition() + pos := position.Position{ + IteratorType: position.TypeSnapshot, + PartitionKey: fmt.Sprintf("%v", item[s.partitionKey]), + SortKey: fmt.Sprintf("%v", item[s.sortKey]), + Time: time.Now(), + } + recordPos, err := pos.ToRecordPosition() if err != nil { return opencdc.Record{}, fmt.Errorf("error building snapshot position: %w", err) } // Create the record return sdk.Util.Source.NewRecordSnapshot( - pos, + recordPos, map[string]string{ opencdc.MetadataCollection: s.tableName, }, diff --git a/position/position.go b/position/position.go index 130dc23..dbf73a8 100644 --- a/position/position.go +++ b/position/position.go @@ -32,8 +32,9 @@ type IteratorType int type Position struct { IteratorType IteratorType `json:"iterator_type"` - // the record's key, for the snapshot iterator - Key string `json:"key"` + // the record's keys, for the snapshot iterator + PartitionKey string `json:"partition_key"` + SortKey string `json:"sort_key"` // the record's sequence number in the stream, for the CDC iterator. 
SequenceNumber string `json:"sequence_number"` diff --git a/position/position_test.go b/position/position_test.go index e53dc13..9d1485d 100644 --- a/position/position_test.go +++ b/position/position_test.go @@ -26,14 +26,14 @@ import ( func TestParseSDKPosition(t *testing.T) { validPosition := Position{ IteratorType: TypeCDC, - Key: "key", + PartitionKey: "key", SequenceNumber: "my-sequence-number", Time: time.Time{}, } wrongPosType := Position{ IteratorType: 3, // non-existent type - Key: "key", + PartitionKey: "key", SequenceNumber: "my-sequence-number", } is := is.New(t) diff --git a/source.go b/source.go index 0dc8179..6216130 100644 --- a/source.go +++ b/source.go @@ -18,6 +18,7 @@ package dynamodb import ( "context" + "errors" "fmt" "time" @@ -37,12 +38,11 @@ import ( type Source struct { sdk.UnimplementedSource - config SourceConfig - dynamoDBClient *dynamodb.Client - streamsClient *dynamodbstreams.Client - lastPositionRead opencdc.Position - streamArn string - iterator Iterator + config SourceConfig + dynamoDBClient *dynamodb.Client + streamsClient *dynamodbstreams.Client + streamArn string + iterator Iterator } type SourceConfig struct { @@ -96,7 +96,6 @@ func (s *Source) Open(ctx context.Context, pos opencdc.Position) error { s.dynamoDBClient = dynamodb.NewFromConfig(cfg) s.streamsClient = dynamodbstreams.NewFromConfig(cfg) - s.lastPositionRead = pos partitionKey, sortKey, err := s.getKeyNamesFromTable(ctx) if err != nil { @@ -130,7 +129,7 @@ func (s *Source) Open(ctx context.Context, pos opencdc.Position) error { } func (s *Source) Read(ctx context.Context) (opencdc.Record, error) { - sdk.Logger(ctx).Info().Msg("Reading records from DynamoDB...") + sdk.Logger(ctx).Trace().Msg("Reading records from DynamoDB...") if !s.iterator.HasNext(ctx) { return opencdc.Record{}, sdk.ErrBackoffRetry @@ -202,7 +201,7 @@ func (s *Source) prepareStream(ctx context.Context) error { } if out.Table.LatestStreamArn == nil { - return fmt.Errorf("stream was not enabled 
successfully") + return errors.New("stream was not enabled successfully") } sdk.Logger(ctx).Info().Str("LatestStreamArn", *out.Table.LatestStreamArn).Msg("Stream enabled successfully.")