Skip to content

Commit

Permalink
feat: implement logic for s3 select object content stream
Browse files Browse the repository at this point in the history
  • Loading branch information
benmcclelland committed Dec 6, 2023
1 parent 4881892 commit 19c1395
Showing 1 changed file with 308 additions and 2 deletions.
310 changes: 308 additions & 2 deletions s3select/message-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,334 @@ package s3select
import (
"bufio"
"context"
"encoding/binary"
"encoding/xml"
"fmt"
"hash/crc32"
"sync"
"sync/atomic"
"time"
)

// Protocol definition for messages can be found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html

var (
// From protocol def:
// Enum indicating the header value type.
// For Amazon S3 Select, this is always 7.
// generateHeader writes this byte in front of the two-byte length of
// every header value.
headerValueType = byte(7)
)

// intToTwoBytes encodes the low 16 bits of i as two bytes in big-endian
// order (most significant byte first). Any higher bits of i are discarded.
func intToTwoBytes(i int) []byte {
	out := make([]byte, 2)
	out[0] = byte(i >> 8)
	out[1] = byte(i)
	return out
}

// generateHeader encodes a flat list of alternating header names and
// header values into a single header block.
//
// Entries at even positions are treated as header names and are written
// as a 1-byte length followed by the name bytes. Entries at odd positions
// are treated as header values and are written as the value type byte
// (headerValueType), a 2-byte big-endian length, and the value bytes.
// With no arguments the result is nil.
func generateHeader(messages ...string) []byte {
	var encoded []byte

	for idx := range messages {
		entry := messages[idx]

		switch idx % 2 {
		case 0:
			// header name: single length byte
			encoded = append(encoded, byte(len(entry)))
		default:
			// header value: type byte, then two-byte length
			encoded = append(encoded, headerValueType)
			encoded = append(encoded, intToTwoBytes(len(entry))...)
		}

		encoded = append(encoded, entry...)
	}

	return encoded
}

// generateOctetHeader builds the header block for an event message whose
// payload is "application/octet-stream". The message argument is used as
// the ":event-type" header value.
func generateOctetHeader(message string) []byte {
	fields := []string{
		":message-type", "event",
		":content-type", "application/octet-stream",
		":event-type", message,
	}
	return generateHeader(fields...)
}

// generateTextHeader builds the header block for an event message whose
// payload is "text/xml". The message argument is used as the
// ":event-type" header value.
func generateTextHeader(message string) []byte {
	fields := []string{
		":message-type", "event",
		":content-type", "text/xml",
		":event-type", message,
	}
	return generateHeader(fields...)
}

// generateNoContentHeader builds the header block for an event message
// that carries no ":content-type" header. The message argument is used as
// the ":event-type" header value.
func generateNoContentHeader(message string) []byte {
	fields := []string{
		":message-type", "event",
		":event-type", message,
	}
	return generateHeader(fields...)
}

const (
// preludeLen is the size in bytes of the fixed prelude that starts
// every message:
// 4 bytes total byte len +
// 4 bytes headers bytes len +
// 4 bytes prelude CRC
preludeLen = 12
// msgCrcLen is the size in bytes of the CRC appended to every message.
// CRC is uint32
msgCrcLen = 4
)

// Header blocks and complete messages that never change are encoded once
// here and reused for every message sent.
var (
// recordsHeader is the header block for "Records" events (octet-stream payload).
recordsHeader = generateOctetHeader("Records")
// continuationHeader is the header block for "Cont" events (no content type).
continuationHeader = generateNoContentHeader("Cont")
// continuationMessage is the complete "Cont" message with an empty payload.
continuationMessage = genMessage(continuationHeader, []byte{})
// progressHeader is the header block for "Progress" events (text/xml payload).
progressHeader = generateTextHeader("Progress")
// statsHeader is the header block for "Stats" events (text/xml payload).
statsHeader = generateTextHeader("Stats")
// endHeader is the header block for "End" events (no content type).
endHeader = generateNoContentHeader("End")
// endMessage is the complete "End" message with an empty payload.
endMessage = genMessage(endHeader, []byte{})
)

// uintToBytes returns n encoded as 4 bytes in big-endian order.
func uintToBytes(n uint32) []byte {
	var buf [4]byte
	binary.BigEndian.PutUint32(buf[:], n)
	return buf[:]
}

// generatePrelude builds the 12-byte prelude that starts every message.
//
// msgLen is the payload length and headerLen is the length of the encoded
// header block. The prelude consists of the total message length
// (payload + headers + prelude + trailing message CRC), the header block
// length, and the IEEE CRC32 of those first 8 bytes, each written as a
// 4-byte big-endian value.
func generatePrelude(msgLen int, headerLen int) []byte {
	totalLen := uint32(msgLen + headerLen + preludeLen + msgCrcLen)

	buf := make([]byte, 0, preludeLen)
	// total length, then headers length
	for _, field := range [...]uint32{totalLen, uint32(headerLen)} {
		buf = append(buf, uintToBytes(field)...)
	}

	// the prelude CRC covers the two length fields written above
	crc := crc32.ChecksumIEEE(buf)
	return append(buf, uintToBytes(crc)...)
}

const (
// maxHeaderSize is the largest header block, in bytes, for which
// genMessage pre-allocates the message buffer (1 MiB).
maxHeaderSize = 1024 * 1024
// maxMessageSize is the largest record payload, in bytes, accepted by
// SendRecord (5 GiB).
// NOTE(review): this value is larger than the 4-byte (uint32) total
// length field written by generatePrelude can represent, and as an
// untyped constant it does not fit in a 32-bit int when compared with
// len(...) — confirm the intended limit.
maxMessageSize = 5 * 1024 * 1024 * 1024
)

// genMessage assembles one complete message from an already-encoded
// header block and a payload.
//
// The layout is: prelude, header block, payload, and finally the IEEE
// CRC32 of everything that precedes it, written as 4 big-endian bytes.
func genMessage(header, payload []byte) []byte {
	var buf []byte

	// Pre-size the buffer when both parts are within the expected limits.
	// Callers validate sizes before getting here, so this is the normal
	// path; otherwise the buffer simply grows through append.
	withinLimits := len(header) <= maxHeaderSize && len(payload) <= maxMessageSize
	if withinLimits {
		buf = make([]byte, 0, preludeLen+len(header)+len(payload)+msgCrcLen)
	}

	buf = append(buf, generatePrelude(len(payload), len(header))...)
	buf = append(append(buf, header...), payload...)

	// the message CRC covers prelude, headers and payload
	checksum := crc32.ChecksumIEEE(buf)
	return append(buf, uintToBytes(checksum)...)
}

// genRecordsMessage wraps payload in a complete "Records" event message
// using the pre-built records header.
func genRecordsMessage(payload []byte) []byte {
	recordsMsg := genMessage(recordsHeader, payload)
	return recordsMsg
}

// progress is the XML payload of a "Progress" event. It is marshalled as
// a <Progress> element with one child element per counter.
type progress struct {
XMLName xml.Name `xml:"Progress"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}

// genProgressMessage builds a complete "Progress" event message whose
// payload is the XML header followed by an indented <Progress> document
// holding the three given counters.
func genProgressMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
	body := progress{
		BytesScanned:   bytesScanned,
		BytesProcessed: bytesProcessed,
		BytesReturned:  bytesReturned,
	}

	// The marshal error is ignored, as in the rest of this file: the
	// value is a fixed struct of integer fields.
	xmlData, _ := xml.MarshalIndent(body, "", " ")

	payload := append([]byte(xml.Header), xmlData...)
	return genMessage(progressHeader, payload)
}

// stats is the XML payload of a "Stats" event. It is marshalled as a
// <Stats> element with one child element per counter.
type stats struct {
XMLName xml.Name `xml:"Stats"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}

// genStatsMessage builds a complete "Stats" event message whose payload
// is the XML header followed by an indented <Stats> document holding the
// three given counters.
func genStatsMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
	body := stats{
		BytesScanned:   bytesScanned,
		BytesProcessed: bytesProcessed,
		BytesReturned:  bytesReturned,
	}

	// The marshal error is ignored, as in the rest of this file: the
	// value is a fixed struct of integer fields.
	xmlData, _ := xml.MarshalIndent(body, "", " ")

	payload := append([]byte(xml.Header), xmlData...)
	return genMessage(statsHeader, payload)
}

func genErrorMessage(errorCode, errorMessage string) []byte {
return genMessage(generateHeader(
":error-code",
errorCode,
":error-message",
errorMessage,
":message-type",
"error",
), []byte{})
}

// GetProgress is a callback that reports how many bytes have been scanned
// and processed so far. Its results are used to fill Progress and Stats
// messages.
type GetProgress func() (bytesScanned int64, bytesProcessed int64)

// MessageHandler writes the select-object-content event stream to a
// buffered writer. Create one with NewMessageHandler.
//
// The type is declared exactly once; a leftover empty
// "type MessageHandler struct{}" declaration would redeclare the type and
// stop the package from compiling.
type MessageHandler struct {
	// Mutex serializes writes to writer.
	sync.Mutex
	// ctx and cancel control the lifetime of the handler and of its
	// background goroutine.
	ctx    context.Context
	cancel context.CancelFunc
	// writer is the destination of the event stream.
	writer *bufio.Writer
	// data is allocated by NewMessageHandler.
	// NOTE(review): not referenced elsewhere in this file — confirm
	// whether it is still needed.
	data chan []byte
	// getProgress is optional; when nil no Progress or Stats messages
	// are sent.
	getProgress GetProgress
	// stopCh and resetCh tell the background goroutine to stop and to
	// restart its continuation ticker around a write.
	stopCh  chan bool
	resetCh chan bool
	// bytesReturned is the total payload size of the records sent so
	// far. It is updated with sync/atomic.
	bytesReturned int64
}

// NewMessageHandler creates a MessageHandler that writes the event stream
// to w and starts the background goroutine that sends continuation and
// progress messages.
//
// The handler runs under a context derived from ctx; cancelling ctx stops
// the background goroutine. getProgressFunc may be nil, in which case no
// Progress or Stats messages are sent.
//
// A leftover "return &MessageHandler{}" at the top of this function has
// been removed: it returned a handler with a nil context and no
// background goroutine, and made the real initialization unreachable.
func NewMessageHandler(ctx context.Context, w *bufio.Writer, getProgressFunc GetProgress) *MessageHandler {
	ctx, cancel := context.WithCancel(ctx)

	mh := &MessageHandler{
		ctx:         ctx,
		cancel:      cancel,
		writer:      w,
		data:        make(chan []byte),
		getProgress: getProgressFunc,
		resetCh:     make(chan bool),
		stopCh:      make(chan bool),
	}

	go mh.sendBackgroundMessages(mh.resetCh, mh.stopCh)
	return mh
}

// write sends one already-encoded message to the client and flushes it.
// It is the write path for callers outside the background goroutine.
//
// Around the write it asks the background goroutine to stop and then to
// restart its continuation ticker. Both notifications are sent while the
// mutex is NOT held and give up when the handler's context is done, so:
//   - the background goroutine can never be waited on by a goroutine that
//     holds the mutex it may itself need, and
//   - a caller does not block forever once the background goroutine has
//     exited.
//
// It returns the context error if the handler was cancelled before the
// message could be written, or the error from writing/flushing.
func (mh *MessageHandler) write(data []byte) error {
	if err := mh.signal(mh.stopCh); err != nil {
		return err
	}
	defer func() {
		// Best effort: if the context is already done there is no
		// ticker left to restart.
		_ = mh.signal(mh.resetCh)
	}()

	return mh.writeAndFlush(data)
}

// signal sends a notification on ch, giving up with the context error if
// the handler's context is done before the background goroutine receives
// it.
func (mh *MessageHandler) signal(ch chan<- bool) error {
	select {
	case ch <- true:
		return nil
	case <-mh.ctx.Done():
		return mh.ctx.Err()
	}
}

// writeAndFlush writes data to the underlying writer and flushes it while
// holding the mutex. It does not talk to the background goroutine, so the
// background goroutine uses it for its own messages.
func (mh *MessageHandler) writeAndFlush(data []byte) error {
	mh.Lock()
	defer mh.Unlock()

	if _, err := mh.writer.Write(data); err != nil {
		return err
	}

	return mh.writer.Flush()
}

const (
	// continuationInterval is how often a continuation message is sent
	// while the ticker is running.
	continuationInterval = time.Second
	// progressInterval is how often a progress message is sent when a
	// progress callback is configured.
	progressInterval = time.Minute
)

// sendBackgroundMessages is the body of the background goroutine started
// by NewMessageHandler. It sends a continuation message every
// continuationInterval and, when a progress callback is configured, a
// progress message every progressInterval.
//
// A receive on stopCh stops the continuation ticker and a receive on
// resetCh restarts it. The loop ends when the handler's context is done
// or when a write fails; in the latter case the context is cancelled.
//
// Messages are written with writeAndFlush rather than write: write
// notifies this goroutine over stopCh/resetCh, which are unbuffered and
// only received here, so calling it from this goroutine would block
// forever.
func (mh *MessageHandler) sendBackgroundMessages(resetCh, stopCh <-chan bool) {
	continuationTicker := time.NewTicker(continuationInterval)
	defer continuationTicker.Stop()

	// A nil channel is never ready, so the progress case below is
	// disabled when there is no progress callback.
	var progressTickerChan <-chan time.Time
	if mh.getProgress != nil {
		progressTicker := time.NewTicker(progressInterval)
		defer progressTicker.Stop()
		progressTickerChan = progressTicker.C
	}

	for {
		select {
		case <-mh.ctx.Done():
			return

		case <-continuationTicker.C:
			if err := mh.writeAndFlush(continuationMessage); err != nil {
				mh.cancel()
				return
			}

		case <-resetCh:
			continuationTicker.Reset(continuationInterval)

		case <-stopCh:
			continuationTicker.Stop()

		case <-progressTickerChan:
			bytesScanned, bytesProcessed := mh.getProgress()
			bytesReturned := atomic.LoadInt64(&mh.bytesReturned)
			msg := genProgressMessage(bytesScanned, bytesProcessed, bytesReturned)
			if err := mh.writeAndFlush(msg); err != nil {
				mh.cancel()
				return
			}
		}
	}
}

// SendRecord sends payload to the client as a single Records message.
//
// It returns the context error if the handler has already been cancelled,
// an error if payload is larger than maxMessageSize, or the error from
// writing the message. On success the payload size is added to the
// running count of bytes returned.
func (mh *MessageHandler) SendRecord(payload []byte) error {
	if err := mh.ctx.Err(); err != nil {
		return err
	}

	size := len(payload)
	if size > maxMessageSize {
		return fmt.Errorf("record max size exceeded")
	}

	if err := mh.write(genRecordsMessage(payload)); err != nil {
		return err
	}

	atomic.AddInt64(&mh.bytesReturned, int64(size))
	return nil
}

// Finish terminates the message stream with a Stats message followed by
// an End message, then cancels the handler's context.
//
// The Stats message is only sent when a progress callback is configured.
// It returns the context error if the handler has already been cancelled,
// or the error from writing either message; the context is not cancelled
// here when a write fails.
func (mh *MessageHandler) Finish() error {
	if err := mh.ctx.Err(); err != nil {
		return err
	}

	if mh.getProgress != nil {
		bytesScanned, bytesProcessed := mh.getProgress()
		// bytesReturned is updated with atomic.AddInt64 and read with
		// atomic.LoadInt64 everywhere else, so read it atomically here
		// too instead of with a plain load.
		bytesReturned := atomic.LoadInt64(&mh.bytesReturned)

		if err := mh.write(genStatsMessage(bytesScanned, bytesProcessed, bytesReturned)); err != nil {
			return err
		}
	}

	if err := mh.write(endMessage); err != nil {
		return err
	}

	mh.cancel()
	return nil
}

// FinishWithError terminates the event stream by sending an error message
// built from errorCode and errorMessage, then cancels the handler's
// context.
//
// It returns the context error if the handler has already been cancelled,
// or the error from writing the message; the context is not cancelled
// here when the write fails.
func (mh *MessageHandler) FinishWithError(errorCode, errorMessage string) error {
	if ctxErr := mh.ctx.Err(); ctxErr != nil {
		return ctxErr
	}

	errMsg := genErrorMessage(errorCode, errorMessage)
	if writeErr := mh.write(errMsg); writeErr != nil {
		return writeErr
	}

	mh.cancel()
	return nil
}

0 comments on commit 19c1395

Please sign in to comment.