Skip to content

Commit

Permalink
position handling + address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
maha-hajja authored and maha-hajja committed Oct 3, 2024
1 parent 478bf40 commit 25432e4
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 98 deletions.
151 changes: 108 additions & 43 deletions iterator/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,31 @@ import (

// CDCIterator iterates through the table's stream.
type CDCIterator struct {
tableName string
partitionKey string
sortKey string
streamsClient *dynamodbstreams.Client
lastSequenceNumber *string
cache chan stypes.Record
streamArn string
shardIterator *string
shardIndex int
tomb *tomb.Tomb
ticker *time.Ticker
p position.Position
tableName string
partitionKey string
sortKey string
streamsClient *dynamodbstreams.Client
cache chan stypes.Record
streamArn string
shardIterator *string
shardIndex int
tomb *tomb.Tomb
ticker *time.Ticker
p position.Position
}

// NewCDCIterator initializes a CDCIterator starting from the provided position.
func NewCDCIterator(ctx context.Context, tableName string, pKey string, sKey string, pollingPeriod time.Duration, client *dynamodbstreams.Client, streamArn string, p position.Position) (*CDCIterator, error) {
c := &CDCIterator{
tableName: tableName,
partitionKey: pKey,
sortKey: sKey,
streamsClient: client,
lastSequenceNumber: nil,
streamArn: streamArn,
tomb: &tomb.Tomb{},
cache: make(chan stypes.Record, 10), // todo size?
ticker: time.NewTicker(pollingPeriod),
p: p, // todo position handling when pipeline is restarted
tableName: tableName,
partitionKey: pKey,
sortKey: sKey,
streamsClient: client,
streamArn: streamArn,
tomb: &tomb.Tomb{},
cache: make(chan stypes.Record, 10), // todo size?
ticker: time.NewTicker(pollingPeriod),
p: p,
}
shardIterator, err := c.getShardIterator(ctx)
if err != nil {
Expand Down Expand Up @@ -107,6 +105,7 @@ func (c *CDCIterator) startCDC() error {
case <-c.tomb.Dying():
return fmt.Errorf("tomb is dying: %w", c.tomb.Err())
case <-c.ticker.C: // detect changes every polling period
// todo loop for more than 1000 records
out, err := c.streamsClient.GetRecords(c.tomb.Context(nil), &dynamodbstreams.GetRecordsInput{ //nolint:staticcheck // SA1012 tomb expects nil
ShardIterator: c.shardIterator,
})
Expand All @@ -128,7 +127,7 @@ func (c *CDCIterator) startCDC() error {
c.shardIterator = out.NextShardIterator

if c.shardIterator == nil {
c.shardIterator, err = c.getShardIterator(c.tomb.Context(nil)) //nolint:staticcheck // SA1012 tomb expects nil
c.shardIterator, err = c.moveToNextShard(c.tomb.Context(nil)) //nolint:staticcheck // SA1012 tomb expects nil
if err != nil {
return fmt.Errorf("failed to get shard iterator: %w", err)
}
Expand All @@ -137,38 +136,104 @@ func (c *CDCIterator) startCDC() error {
}
}

// getShardIterator gets the shard iterator for the stream.
// getShardIterator gets the shard iterator for the stream depending on the position.
// if position is nil, it starts from the last shard with the "LATEST" iterator type.
// if position has a sequence number, it starts from the shard containing that sequence number, with the "AFTER_SEQUENCE_NUMBER" iterator type.
// if the given sequence number was not found (expired) or was at the end of the shard, then we start from the following shard to it, with the "TRIM_HORIZON" iterator type.
func (c *CDCIterator) getShardIterator(ctx context.Context) (*string, error) {
// Describe the stream to get the shard ID
// describe the stream to get the shard ID.
describeStreamOutput, err := c.streamsClient.DescribeStream(ctx, &dynamodbstreams.DescribeStreamInput{
StreamArn: aws.String(c.streamArn),
})
if err != nil {
return nil, fmt.Errorf("failed to describe stream: %w", err)
}

// Start from the last shard
if c.shardIndex == 0 {
c.shardIndex = len(describeStreamOutput.StreamDescription.Shards) - 1
shards := describeStreamOutput.StreamDescription.Shards
if len(shards) == 0 {
return nil, errors.New("no shards found in the stream")
}

var selectedShardID string
shardIteratorType := stypes.ShardIteratorTypeLatest
// If position has a sequence number, find the shard containing it
if c.p.SequenceNumber != "" {
for i, shard := range shards {
// if the sequence number is at the end of the shard, save the index od the shard to use with the "TRIM_HORIZON" case.
if shard.SequenceNumberRange.EndingSequenceNumber != nil && c.p.SequenceNumber == *shard.SequenceNumberRange.EndingSequenceNumber {
selectedShardID = *shards[i+1].ShardId // Start from shard following this one.
shardIteratorType = stypes.ShardIteratorTypeTrimHorizon
break
}
// find the shard that contains the position's sequence number.
if *shard.SequenceNumberRange.StartingSequenceNumber <= c.p.SequenceNumber &&
(shard.SequenceNumberRange.EndingSequenceNumber == nil ||
c.p.SequenceNumber < *shard.SequenceNumberRange.EndingSequenceNumber) {
selectedShardID = *shard.ShardId
shardIteratorType = stypes.ShardIteratorTypeAfterSequenceNumber
break
}
}

// if no shard was found containing the sequence number, then start from the beginning of the shards.
if selectedShardID == "" {
selectedShardID = *shards[0].ShardId // Start from the first shard
shardIteratorType = stypes.ShardIteratorTypeTrimHorizon
sdk.Logger(ctx).Warn().Msg("The given sequence number is expired, will start getting events from the beginning of the stream.")
}
} else {
c.shardIndex++
// no sequence number, select the latest shard (the last one in the list)
selectedShardID = *shards[len(shards)-1].ShardId
shardIteratorType = stypes.ShardIteratorTypeLatest
}
shardID := describeStreamOutput.StreamDescription.Shards[c.shardIndex].ShardId

// now that we have the shard ID, we can fetch the shard iterator
return c.getShardIteratorForShard(ctx, selectedShardID, shardIteratorType)
}

// getShardIteratorForShard gets the shard iterator for a specific shard.
func (c *CDCIterator) getShardIteratorForShard(ctx context.Context, shardID string, shardIteratorType stypes.ShardIteratorType) (*string, error) {
input := &dynamodbstreams.GetShardIteratorInput{
StreamArn: aws.String(c.streamArn),
ShardId: shardID,
ShardIteratorType: stypes.ShardIteratorTypeLatest,
ShardId: aws.String(shardID),
ShardIteratorType: shardIteratorType,
}
// continue from a specific position.
if shardIteratorType == stypes.ShardIteratorTypeAfterSequenceNumber {
input.SequenceNumber = aws.String(c.p.SequenceNumber)
}
// Get the shard iterator

// get the shard iterator.
getShardIteratorOutput, err := c.streamsClient.GetShardIterator(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to get shard iterator: %w", err)
return nil, fmt.Errorf("failed to get shard iterator for shard %s: %w", shardID, err)
}

return getShardIteratorOutput.ShardIterator, nil
}

// moveToNextShard used to get the iterator of the shard that follows the current one after it was closed.
func (c *CDCIterator) moveToNextShard(ctx context.Context) (*string, error) {
// describe the stream to get the shard details.
describeStreamOutput, err := c.streamsClient.DescribeStream(ctx, &dynamodbstreams.DescribeStreamInput{
StreamArn: aws.String(c.streamArn),
})
if err != nil {
return nil, fmt.Errorf("failed to describe stream: %w", err)
}

shards := describeStreamOutput.StreamDescription.Shards

// move to the next shard if the current one is closed
if c.shardIndex+1 < len(shards) {
c.shardIndex++
nextShard := shards[c.shardIndex]
// get shard iterator for the new shard
return c.getShardIteratorForShard(ctx, *nextShard.ShardId, stypes.ShardIteratorTypeTrimHorizon)
}
return nil, errors.New("no more shards available")
}

func (c *CDCIterator) getRecMap(item map[string]stypes.AttributeValue) map[string]interface{} { //nolint:dupl // different types
stringMap := make(map[string]interface{})
for k, v := range item {
Expand Down Expand Up @@ -213,22 +278,22 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) {
if rec.EventName == stypes.OperationTypeRemove {
image = oldImage
}

// prepare key and position
structuredKey := opencdc.StructuredData{
c.partitionKey: image[c.partitionKey],
}
c.p.Key = *rec.Dynamodb.SequenceNumber
c.p.IteratorType = position.TypeCDC
if c.sortKey != "" {
c.p.Key = c.p.Key + "." + fmt.Sprintf("%v", image[c.sortKey])
structuredKey = opencdc.StructuredData{
c.partitionKey: image[c.partitionKey],
c.sortKey: image[c.sortKey],
}
}
c.p.Time = time.Now()
pos, err := c.p.ToRecordPosition()
pos := position.Position{
IteratorType: position.TypeCDC,
SequenceNumber: *rec.Dynamodb.SequenceNumber,
Time: time.Now(),
}
cdcPos, err := pos.ToRecordPosition()
if err != nil {
return opencdc.Record{}, fmt.Errorf("failed to build record's CDC position: %w", err)
}
Expand All @@ -237,7 +302,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) {
switch rec.EventName {
case stypes.OperationTypeInsert:
return sdk.Util.Source.NewRecordCreate(
pos,
cdcPos,
map[string]string{
opencdc.MetadataCollection: c.tableName,
},
Expand All @@ -246,7 +311,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) {
), nil
case stypes.OperationTypeModify:
return sdk.Util.Source.NewRecordUpdate(
pos,
cdcPos,
map[string]string{
opencdc.MetadataCollection: c.tableName,
},
Expand All @@ -256,7 +321,7 @@ func (c *CDCIterator) getOpenCDCRec(rec stypes.Record) (opencdc.Record, error) {
), nil
case stypes.OperationTypeRemove:
return sdk.Util.Source.NewRecordDelete(
pos,
cdcPos,
map[string]string{
opencdc.MetadataCollection: c.tableName,
},
Expand Down
38 changes: 9 additions & 29 deletions iterator/combined_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func NewCombinedIterator(

switch p.IteratorType {
case position.TypeSnapshot:
if len(p.Key) != 0 {
if len(p.PartitionKey) != 0 {
sdk.Logger(ctx).
Warn().
Msg("previous snapshot did not complete successfully. snapshot will be restarted for consistency.")
Expand All @@ -72,6 +72,11 @@ func NewCombinedIterator(
if err != nil {
return nil, fmt.Errorf("could not create the snapshot iterator: %w", err)
}
// start listening for changes while snapshot is running
c.cdcIterator, err = NewCDCIterator(ctx, tableName, pKey, sKey, pollingPeriod, streamsClient, streamArn, position.Position{})
if err != nil {
return nil, fmt.Errorf("could not create the CDC iterator: %w", err)
}
case position.TypeCDC:
c.cdcIterator, err = NewCDCIterator(ctx, tableName, pKey, sKey, pollingPeriod, streamsClient, streamArn, p)
if err != nil {
Expand All @@ -86,15 +91,7 @@ func NewCombinedIterator(
func (c *CombinedIterator) HasNext(ctx context.Context) bool {
switch {
case c.snapshotIterator != nil:
// if snapshot is over
if !c.snapshotIterator.HasNext(ctx) {
err := c.switchToCDCIterator(ctx)
if err != nil {
return false
}
return false
}
return true
return c.snapshotIterator.HasNext(ctx)
case c.cdcIterator != nil:
return c.cdcIterator.HasNext(ctx)
default:
Expand All @@ -110,12 +107,8 @@ func (c *CombinedIterator) Next(ctx context.Context) (opencdc.Record, error) {
return opencdc.Record{}, err
}
if !c.snapshotIterator.HasNext(ctx) {
// switch to cdc iterator
err := c.switchToCDCIterator(ctx)
if err != nil {
return opencdc.Record{}, err
}
// change the last record's position to CDC
c.snapshotIterator = nil
// change the last snapshot record's position to CDC
r.Position, err = position.ConvertToCDCPosition(r.Position)
if err != nil {
return opencdc.Record{}, fmt.Errorf("error converting position to CDC: %w", err)
Expand All @@ -135,16 +128,3 @@ func (c *CombinedIterator) Stop() {
c.cdcIterator.Stop()
}
}

func (c *CombinedIterator) switchToCDCIterator(ctx context.Context) error {
var err error
pos := position.Position{
IteratorType: position.TypeCDC,
}
c.cdcIterator, err = NewCDCIterator(ctx, c.tableName, c.partitionKey, c.sortKey, c.pollingPeriod, c.streamsClient, c.streamArn, pos)
if err != nil {
return fmt.Errorf("could not create cdc iterator: %w", err)
}
c.snapshotIterator = nil
return nil
}
29 changes: 16 additions & 13 deletions iterator/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ func NewSnapshotIterator(tableName string, pKey string, sKey string, client *dyn
}, nil
}

// todo can use dynamodb ScanPaginator instead
// refreshPage fetches the next page of items from DynamoDB.
func (s *SnapshotIterator) refreshPage(ctx context.Context) error {
s.items = nil
Expand Down Expand Up @@ -87,7 +86,11 @@ func (s *SnapshotIterator) HasNext(ctx context.Context) bool {
if s.lastEvaluatedKey != nil || s.firstIt {
s.firstIt = false
err := s.refreshPage(ctx)
return err == nil
if err != nil {
sdk.Logger(ctx).Error().Err(err).Msg("failed to get the next page of the snapshot.")
return false
}
return true
}
return false
}
Expand All @@ -106,29 +109,29 @@ func (s *SnapshotIterator) Stop() {
}

func (s *SnapshotIterator) buildOpenCDCRecord(item map[string]interface{}) (opencdc.Record, error) {
var structuredKey opencdc.StructuredData
s.p.Key = fmt.Sprintf("%v", item[s.partitionKey])
structuredKey := opencdc.StructuredData{
s.partitionKey: item[s.partitionKey],
}
if s.sortKey != "" {
s.p.Key = s.p.Key + "." + fmt.Sprintf("%v", item[s.sortKey])
structuredKey = opencdc.StructuredData{
s.partitionKey: item[s.partitionKey],
s.sortKey: item[s.sortKey],
}
} else {
structuredKey = opencdc.StructuredData{
s.partitionKey: item[s.partitionKey],
}
}
s.p.IteratorType = position.TypeSnapshot
s.p.Time = time.Now()
pos, err := s.p.ToRecordPosition()
pos := position.Position{
IteratorType: position.TypeSnapshot,
PartitionKey: fmt.Sprintf("%v", item[s.partitionKey]),
SortKey: fmt.Sprintf("%v", item[s.sortKey]),
Time: time.Now(),
}
recordPos, err := pos.ToRecordPosition()
if err != nil {
return opencdc.Record{}, fmt.Errorf("error building snapshot position: %w", err)
}

// Create the record
return sdk.Util.Source.NewRecordSnapshot(
pos,
recordPos,
map[string]string{
opencdc.MetadataCollection: s.tableName,
},
Expand Down
5 changes: 3 additions & 2 deletions position/position.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ type IteratorType int
type Position struct {
IteratorType IteratorType `json:"iterator_type"`

// the record's key, for the snapshot iterator
Key string `json:"key"`
// the record's keys, for the snapshot iterator
PartitionKey string `json:"partition_key"`
SortKey string `json:"sort_key"`

// the record's sequence number in the stream, for the CDC iterator.
SequenceNumber string `json:"sequence_number"`
Expand Down
Loading

0 comments on commit 25432e4

Please sign in to comment.