diff --git a/constants/errors.go b/constants/errors.go new file mode 100644 index 0000000..275707c --- /dev/null +++ b/constants/errors.go @@ -0,0 +1,9 @@ +package constants + +const ( + GenericPublishError = "GenericPublishError" +) + +var ErrorStrings = map[string]string{ + GenericPublishError: "publish error", +} diff --git a/constants/publisher.go b/constants/publisher.go new file mode 100644 index 0000000..c45da30 --- /dev/null +++ b/constants/publisher.go @@ -0,0 +1,5 @@ +package constants + +const ( + MaxBatchSize = 10 // 10 is the maximum batch size for SNS.PublishBatch +) diff --git a/publisher/models/message.go b/publisher/models/message.go new file mode 100644 index 0000000..9819420 --- /dev/null +++ b/publisher/models/message.go @@ -0,0 +1,6 @@ +package models + +type Message struct { + ID string `json:"id"` + Data interface{} `json:"data"` +} diff --git a/publisher/sns/sns.go b/publisher/sns/sns.go index e215fd6..7c1f41e 100644 --- a/publisher/sns/sns.go +++ b/publisher/sns/sns.go @@ -3,13 +3,15 @@ package sns import ( "context" "encoding/json" + "errors" "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sns" - "github.com/google/uuid" + "github.com/creatorstack/htsqs/constants" + "github.com/creatorstack/htsqs/publisher/models" ) // sender is the interface to sns.SNS. Its sole purpose is to make @@ -67,19 +69,22 @@ func (p *Publisher) Publish(ctx context.Context, msg interface{}) error { // kept under 100 messages so that all messages can be published in 10 tries. In case // of failure when parsing or publishing any of the messages, this function will stop // further publishing and return an error -func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error { +func (p *Publisher) PublishBatch(ctx context.Context, msgs []models.Message) (map[string]error, int64, int64, error) { var ( defaultMessageGroupID = "default" + publishResult = make(map[string]error) err error - ) - isFifo := strings.Contains(strings.ToLower(p.cfg.TopicArn), "fifo") + errorCount int64 + successCount int64 - var ( numPublishedMessages = 0 start = 0 - end = 10 // 10 is the maximum batch size for SNS.PublishBatch + end = constants.MaxBatchSize ) + + isFifo := strings.Contains(strings.ToLower(p.cfg.TopicArn), "fifo") + if end > len(msgs) { end = len(msgs) } @@ -90,14 +95,13 @@ func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error for idx := start; idx < end; idx++ { msg := msgs[idx] - b, err := json.Marshal(msg) + b, err := json.Marshal(msg.Data) if err != nil { - return err + return publishResult, successCount, errorCount, err } - entryId := uuid.New().String() requestEntry := &sns.PublishBatchRequestEntry{ - Id: aws.String(entryId), + Id: aws.String(msg.ID), Message: aws.String(string(b)), } @@ -112,20 +116,38 @@ func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error PublishBatchRequestEntries: requestEntries, TopicArn: &p.cfg.TopicArn, } - _, err = p.sns.PublishBatchWithContext(ctx, input) + response, err := p.sns.PublishBatchWithContext(ctx, input) if err != nil { - return err + return publishResult, successCount, errorCount, err + } + + for _, errEntry := range response.Failed { + if errEntry != nil && errEntry.Id != nil { + errMsg := constants.GenericPublishError + if errEntry.Message != nil { + errMsg = *errEntry.Message + } + publishResult[*errEntry.Id] = errors.New(errMsg) + errorCount++ + } + } + + for _, successEntry := range response.Successful { + if successEntry != nil && successEntry.Id != nil { + publishResult[*successEntry.Id] = nil + successCount++ + } } numPublishedMessages += len(requestEntries) start = end - end += 10 + end += constants.MaxBatchSize if end > len(msgs) { end = len(msgs) } } - return err + return publishResult, successCount, errorCount, err } func defaultPublisherConfig(cfg *Config) { diff --git a/publisher/sns/sns_test.go b/publisher/sns/sns_test.go index b0feddf..0cb5a63 100644 --- a/publisher/sns/sns_test.go +++ b/publisher/sns/sns_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws/session" + "github.com/creatorstack/htsqs/publisher/models" "github.com/stretchr/testify/require" ) @@ -27,9 +28,15 @@ func TestPublisher(t *testing.T) { } func TestPublisherBatch(t *testing.T) { - inputs := []interface{}{ - jsonString(`{"key":"val1"}`), - jsonString(`{"key":"val2"}`), + inputs := []models.Message{ + { + ID: "1", + Data: jsonString(`{"key":"val1"}`), + }, + { + ID: "2", + Data: jsonString(`{"key":"val2"}`), + }, } queue := make(chan *string, len(inputs)) @@ -38,12 +45,14 @@ func TestPublisherBatch(t *testing.T) { pubs := New(Config{}) pubs.sns = &snsPublisherMock{queue: queue} - require.NoError(t, pubs.PublishBatch(context.TODO(), inputs)) + _, _, _, err := pubs.PublishBatch(context.TODO(), inputs) + + require.NoError(t, err) idx := 0 for v := range queue { publishedMessage := *v - require.Equal(t, jsonString(publishedMessage), inputs[idx]) + require.Equal(t, jsonString(publishedMessage), inputs[idx].Data) idx++ if idx >= len(inputs) { break