Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions example-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ role: generate
# the starenv derefers, lambdafy adds the following derefers:
#
# - lambdafy_sqs_send: This derefer will be replaced with a URL which when POSTed
# to will send a message to the SQS queue whose ARN is specified. The body
# of the POST will be sent as the SQS message body. If header
# 'Lambdafy-SQS-Group-Id' is set, it will be used as Group ID for the
# to will send a message to the SQS queue whose ARN is specified. This accepts
# either a JSON array of messages or a single message. If an array, the body
# of the POST will be split into batches and sent as entries in SQS send message batch.
# Otherwise, if a single messsage, the body of the POST will be sent as the SQS message body.
# If header 'Lambdafy-SQS-Group-Id' is set, it will be used as Group ID for the
# message. A 2xx/3xx response is considered a success, otherwise a fail. See
# the example below for usage.
# Note: The necessary IAM role permissions to send SQS messages are added
Expand Down
129 changes: 121 additions & 8 deletions proxy/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@ package main
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"mime"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
sqs "github.com/aws/aws-sdk-go-v2/service/sqs"
sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
)

const maxSQSBatchSize = 10 // SQS allows a maximum of 10 messages per batch

var sqsARNPat = regexp.MustCompile(`^arn:aws:sqs:([^:]+):([^:]+):(.+)$`)

// getSQSQueueURL returns the URL of the SQS queue given its ARN.
Expand Down Expand Up @@ -128,10 +134,13 @@ func (d sqsSendDerefer) Deref(arn string) (string, error) {
var sqsIDToQueueURL = sqsSendDerefer{}

const sqsGroupIDHeader = "Lambdafy-SQS-Group-Id"
const batchMessageHeader = "Lambdafy-SQS-Batch-Message"

// handleSQSSend handles HTTP POST requests and translates them to SQS send
// message.
// Lambdafy-SQS-Group-Id header is used to set the message group ID.
// Lambdafy-SQS-Batch-Message header is used to indicate that the request body
// contains a JSON array of messages to be sent in a batch.
func handleSQSSend(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
Expand Down Expand Up @@ -168,18 +177,122 @@ func handleSQSSend(w http.ResponseWriter, r *http.Request) {
}
sqsCl := sqs.NewFromConfig(c)

if _, err := sqsCl.SendMessage(context.Background(), &sqs.SendMessageInput{
MessageBody: aws.String(string(body)),
QueueUrl: aws.String(qURL),
MessageGroupId: groupID,
}); err != nil {
log.Printf("error sending SQS message: %v", err)
http.Error(w, fmt.Sprintf("Error sending SQS message: %v", err), http.StatusInternalServerError)
isBatchMessage := r.Header.Get(batchMessageHeader) != ""
// Single message - use regular send
if !isBatchMessage {
if _, err := sqsCl.SendMessage(context.Background(), &sqs.SendMessageInput{
MessageBody: aws.String(string(body)),
QueueUrl: aws.String(qURL),
MessageGroupId: groupID,
}); err != nil {
log.Printf("error sending SQS message: %v", err)
http.Error(w, fmt.Sprintf("Error sending SQS message: %v", err), http.StatusInternalServerError)
return
}

log.Printf("sent an SQS message to '%s'", qURL)
return
}

// Batch send message - expect the correct Content-Type and
// a JSON array of string messages in the request body

// Check if the Content-Type media type is application/json
// instead of direct string equality check, as it may contain additional parameters.
contentType := r.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)

if err != nil {
log.Printf("error parsing Content-Type header: %v", err)
http.Error(w, fmt.Sprintf("Error parsing Content-Type header: %v", err), http.StatusBadRequest)
return
}
if mediaType != "application/json" {
http.Error(w, "Content-Type must be application/json for batch messages", http.StatusBadRequest)
return
}

var messages []string
if err := json.Unmarshal(body, &messages); err != nil {
log.Printf("Send message batch failure - Invalid JSON array: %v", err)
http.Error(w, "Invalid JSON array", http.StatusBadRequest)
return
}

log.Printf("sent an SQS message to '%s'", qURL)
if len(messages) == 0 {
log.Printf("Send message batch failure - Empty message array")
http.Error(w, "Empty message array", http.StatusBadRequest)
return
}
if len(messages) > maxSQSBatchSize {
log.Printf("Send message batch failure - Too many messages in batch, maximum is %d", maxSQSBatchSize)
http.Error(w, fmt.Sprintf("Too many messages in batch, maximum is %d", maxSQSBatchSize), http.StatusBadRequest)
return
}

entries := make([]sqstypes.SendMessageBatchRequestEntry, len(messages))
for j, msg := range messages {
entries[j] = sqstypes.SendMessageBatchRequestEntry{
Id: aws.String(fmt.Sprintf("%d", j)),
MessageBody: aws.String(msg),
MessageGroupId: groupID,
}
}

var attempts int = 0
var retryable_entries []sqstypes.SendMessageBatchRequestEntry = entries
var nonRetryableEntries []sqstypes.SendMessageBatchRequestEntry = nil

for (attempts == 0 || len(retryable_entries) > 0) && attempts < 5 {
// Sleep for exponential backoff on retry
if attempts > 0 {
// bit shift to calculate the sleep duration -> 500ms, 1s, 2s, 4s, 8s
sleepDuration := (1 << attempts) * 500 // Exponential backoff in milliseconds
time.Sleep(time.Duration(sleepDuration) * time.Millisecond)
}

attempts++
output, err := sqsCl.SendMessageBatch(context.Background(), &sqs.SendMessageBatchInput{
QueueUrl: aws.String(qURL),
Entries: retryable_entries,
})
if err != nil {
log.Printf("error sending SQS message batch: %v", err)
http.Error(w, fmt.Sprintf("Error sending SQS message batch: %v", err), http.StatusInternalServerError)
return
}
retryable_entries = nil // Reset retryable entries for the next attempt
if len(output.Failed) > 0 {
log.Printf("failed to send %d SQS messages in batch", len(output.Failed))
for _, f := range output.Failed {
fmt.Printf(
"failed to send SQS message %s: %s (SenderFault: %t, Code: %s)\n",
*f.Id, *f.Message, f.SenderFault, *f.Code,
)
id, err := strconv.Atoi(*f.Id)
if err != nil {
log.Printf("error parsing SQS message ID '%s': %v", *f.Id, err)
http.Error(w, fmt.Sprintf("Error parsing SQS message ID '%s': %v", *f.Id, err), http.StatusInternalServerError)
return
}
if f.SenderFault {
// Non-retryable error
nonRetryableEntries = append(nonRetryableEntries, entries[id])
} else {
// Retryable error
retryable_entries = append(retryable_entries, entries[id])
}
}
}
}

if len(retryable_entries)+len(nonRetryableEntries) > 0 {
log.Printf("%d of %d SQS messages in batch failed", len(retryable_entries)+len(nonRetryableEntries), len(entries))
http.Error(w, fmt.Sprintf("%d of %d SQS messages in batch failed", len(retryable_entries)+len(nonRetryableEntries), len(entries)), http.StatusInternalServerError)
return
}

log.Printf("sent %d SQS messages to '%s'", len(messages), qURL)
}

const sendSQSStarenvTag = "lambdafy_sqs_send"