diff --git a/README.md b/README.md index 6041a60..fda7072 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,7 @@ under [the specified folder](examples/with-sasl-plaintext) and then start the ap | `retryConfiguration.sasl.password` | SCRAM OR PLAIN password | | | `retryConfiguration.skipMessageByHeaderFn` | Function to filter messages based on headers, return true if you want to skip the message | nil | | `batchConfiguration.messageGroupLimit` | Maximum number of messages in a batch | | +| `batchConfiguration.messageGroupByteSizeLimit` | Maximum number of bytes in a batch | | | `batchConfiguration.batchConsumeFn` | Kafka batch consumer function, if retry enabled it, is also used to consume retriable messages | | | `batchConfiguration.preBatchFn` | This function enable for transforming messages before batch consuming starts | | | `batchConfiguration.balancer` | [see doc](https://pkg.go.dev/github.com/segmentio/kafka-go#Balancer) | leastBytes | diff --git a/batch_consumer.go b/batch_consumer.go index 647d065..e949e90 100644 --- a/batch_consumer.go +++ b/batch_consumer.go @@ -17,7 +17,8 @@ type batchConsumer struct { consumeFn BatchConsumeFn preBatchFn PreBatchFn - messageGroupLimit int + messageGroupLimit int + messageGroupByteSizeLimit int } func (b *batchConsumer) Pause() { @@ -34,11 +35,17 @@ func newBatchConsumer(cfg *ConsumerConfig) (Consumer, error) { return nil, err } + messageGroupByteSizeLimit, err := resolveUnionIntOrStringValue(cfg.BatchConfiguration.MessageGroupByteSizeLimit) + if err != nil { + return nil, err + } + c := batchConsumer{ - base: consumerBase, - consumeFn: cfg.BatchConfiguration.BatchConsumeFn, - preBatchFn: cfg.BatchConfiguration.PreBatchFn, - messageGroupLimit: cfg.BatchConfiguration.MessageGroupLimit, + base: consumerBase, + consumeFn: cfg.BatchConfiguration.BatchConsumeFn, + preBatchFn: cfg.BatchConfiguration.PreBatchFn, + messageGroupLimit: cfg.BatchConfiguration.MessageGroupLimit, + messageGroupByteSizeLimit: messageGroupByteSizeLimit, } if cfg.RetryEnabled { @@ -86,9 +93,10 @@ func (b *batchConsumer) startBatch() { defer ticker.Stop() maximumMessageLimit := b.messageGroupLimit * b.concurrency + maximumMessageByteSizeLimit := b.messageGroupByteSizeLimit * b.concurrency messages := make([]*Message, 0, maximumMessageLimit) commitMessages := make([]kafka.Message, 0, maximumMessageLimit) - + messageByteSize := 0 for { select { case <-ticker.C: @@ -96,7 +104,7 @@ func (b *batchConsumer) startBatch() { continue } - b.consume(&messages, &commitMessages) + b.consume(&messages, &commitMessages, &messageByteSize) case msg, ok := <-b.incomingMessageStream: if !ok { close(b.batchConsumingStream) @@ -104,11 +112,20 @@ func (b *batchConsumer) startBatch() { return } + msgSize := msg.message.TotalSize() + + // Check if there is an enough byte in batch, if not flush it. + if maximumMessageByteSizeLimit != 0 && messageByteSize+msgSize > maximumMessageByteSizeLimit { + b.consume(&messages, &commitMessages, &messageByteSize) + } + messages = append(messages, msg.message) commitMessages = append(commitMessages, *msg.kafkaMessage) + messageByteSize += msgSize + // Check if there is an enough size in batch, if not flush it. if len(messages) == maximumMessageLimit { - b.consume(&messages, &commitMessages) + b.consume(&messages, &commitMessages, &messageByteSize) } } } @@ -126,31 +143,50 @@ func (b *batchConsumer) setupConcurrentWorkers() { } } -func chunkMessages(allMessages *[]*Message, chunkSize int) [][]*Message { +func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message { var chunks [][]*Message allMessageList := *allMessages - for i := 0; i < len(allMessageList); i += chunkSize { - end := i + chunkSize - - // necessary check to avoid slicing beyond - // slice capacity - if end > len(allMessageList) { - end = len(allMessageList) + var currentChunk []*Message + currentChunkSize := 0 + currentChunkBytes := 0 + + for _, message := range allMessageList { + messageByteSize := len(message.Value) + + // Check if adding this message would exceed either the chunk size or the byte size + if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) { + // Avoid too low chunkByteSize + if len(currentChunk) == 0 { + panic("invalid chunk byte size, please increase it") + } + // If it does, finalize the current chunk and start a new one + chunks = append(chunks, currentChunk) + currentChunk = []*Message{} + currentChunkSize = 0 + currentChunkBytes = 0 } - chunks = append(chunks, allMessageList[i:end]) + // Add the message to the current chunk + currentChunk = append(currentChunk, message) + currentChunkSize++ + currentChunkBytes += messageByteSize + } + + // Add the last chunk if it has any messages + if len(currentChunk) > 0 { + chunks = append(chunks, currentChunk) } return chunks } -func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message) { - chunks := chunkMessages(allMessages, b.messageGroupLimit) +func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) { + chunks := chunkMessages(allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit) if b.preBatchFn != nil { preBatchResult := b.preBatchFn(*allMessages) - chunks = chunkMessages(&preBatchResult, b.messageGroupLimit) + chunks = chunkMessages(&preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit) } // Send the messages to process @@ -170,6 +206,7 @@ func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka // Clearing resources *commitMessages = (*commitMessages)[:0] *allMessages = (*allMessages)[:0] + *messageByteSizeLimit = 0 } func (b *batchConsumer) process(chunkMessages []*Message) { diff --git a/batch_consumer_test.go b/batch_consumer_test.go index 55d1684..9cccaae 100644 --- a/batch_consumer_test.go +++ b/batch_consumer_test.go @@ -301,13 +301,15 @@ func Test_batchConsumer_process(t *testing.T) { func Test_batchConsumer_chunk(t *testing.T) { tests := []struct { - allMessages []*Message - expected [][]*Message - chunkSize int + allMessages []*Message + expected [][]*Message + chunkSize int + chunkByteSize int }{ { - allMessages: createMessages(0, 9), - chunkSize: 3, + allMessages: createMessages(0, 9), + chunkSize: 3, + chunkByteSize: 10000, expected: [][]*Message{ createMessages(0, 3), createMessages(3, 6), @@ -315,20 +317,23 @@ func Test_batchConsumer_chunk(t *testing.T) { }, }, { - allMessages: []*Message{}, - chunkSize: 3, - expected: [][]*Message{}, + allMessages: []*Message{}, + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{}, }, { - allMessages: createMessages(0, 1), - chunkSize: 3, + allMessages: createMessages(0, 1), + chunkSize: 3, + chunkByteSize: 10000, expected: [][]*Message{ createMessages(0, 1), }, }, { - allMessages: createMessages(0, 8), - chunkSize: 3, + allMessages: createMessages(0, 8), + chunkSize: 3, + chunkByteSize: 10000, expected: [][]*Message{ createMessages(0, 3), createMessages(3, 6), @@ -336,17 +341,29 @@ func Test_batchConsumer_chunk(t *testing.T) { }, }, { - allMessages: createMessages(0, 3), - chunkSize: 3, + allMessages: createMessages(0, 3), + chunkSize: 3, + chunkByteSize: 10000, expected: [][]*Message{ createMessages(0, 3), }, }, + + { + allMessages: createMessages(0, 3), + chunkSize: 100, + chunkByteSize: 4, + expected: [][]*Message{ + createMessages(0, 1), + createMessages(1, 2), + createMessages(2, 3), + }, + }, } for i, tc := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { - chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize) + chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize, tc.chunkByteSize) if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) { t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages) @@ -444,6 +461,7 @@ func createMessages(partitionStart int, partitionEnd int) []*Message { for i := partitionStart; i < partitionEnd; i++ { messages = append(messages, &Message{ Partition: i, + Value: []byte("test"), }) } return messages diff --git a/consumer_config.go b/consumer_config.go index 3610f6b..eb24c5f 100644 --- a/consumer_config.go +++ b/consumer_config.go @@ -165,9 +165,10 @@ type RetryConfiguration struct { } type BatchConfiguration struct { - BatchConsumeFn BatchConsumeFn - PreBatchFn PreBatchFn - MessageGroupLimit int + BatchConsumeFn BatchConsumeFn + PreBatchFn PreBatchFn + MessageGroupLimit int + MessageGroupByteSizeLimit any } func (cfg *ConsumerConfig) newKafkaDialer() (*kafka.Dialer, error) { diff --git a/data_units.go b/data_units.go new file mode 100644 index 0000000..3b7b7ee --- /dev/null +++ b/data_units.go @@ -0,0 +1,63 @@ +package kafka + +import ( + "fmt" + "strconv" + "strings" +) + +func resolveUnionIntOrStringValue(input any) (int, error) { + switch value := input.(type) { + case int: + return value, nil + case uint: + return int(value), nil + case nil: + return 0, nil + case string: + intValue, err := strconv.ParseInt(value, 10, 64) + if err == nil { + return int(intValue), nil + } + + result, err := convertSizeUnitToByte(value) + if err != nil { + return 0, err + } + + return result, nil + } + + return 0, fmt.Errorf("invalid input: %v", input) +} + +func convertSizeUnitToByte(str string) (int, error) { + if len(str) < 2 { + return 0, fmt.Errorf("invalid input: %s", str) + } + + // Extract the numeric part of the input + sizeStr := str[:len(str)-2] + sizeStr = strings.TrimSpace(sizeStr) + sizeStr = strings.ReplaceAll(sizeStr, ",", ".") + + size, err := strconv.ParseFloat(sizeStr, 64) + if err != nil { + return 0, fmt.Errorf("cannot extract numeric part for the input %s, err = %w", str, err) + } + + // Determine the unit (B, KB, MB, GB) + unit := str[len(str)-2:] + switch strings.ToUpper(unit) { + case "B": + return int(size), nil + case "KB": + return int(size * 1024), nil + case "MB": + return int(size * 1024 * 1024), nil + case "GB": + return int(size * 1024 * 1024 * 1024), nil + default: + return 0, fmt.Errorf("unsupported unit: %s, you can specify one of B, KB, MB and GB", unit) + } +} diff --git a/data_units_test.go b/data_units_test.go new file mode 100644 index 0000000..a9c762e --- /dev/null +++ b/data_units_test.go @@ -0,0 +1,85 @@ +package kafka + +import "testing" + +func TestDcp_ResolveConnectionBufferSize(t *testing.T) { + tests := []struct { + input any + name string + want int + }{ + { + name: "When_Client_Gives_Int_Value", + input: 20971520, + want: 20971520, + }, + { + name: "When_Client_Gives_UInt_Value", + input: uint(10971520), + want: 10971520, + }, + { + name: "When_Client_Gives_StringInt_Value", + input: "15971520", + want: 15971520, + }, + { + name: "When_Client_Gives_KB_Value", + input: "500kb", + want: 500 * 1024, + }, + { + name: "When_Client_Gives_MB_Value", + input: "10mb", + want: 10 * 1024 * 1024, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, _ := resolveUnionIntOrStringValue(tt.input); got != tt.want { + t.Errorf("ResolveConnectionBufferSize() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConvertToBytes(t *testing.T) { + testCases := []struct { + input string + expected int + err bool + }{ + {"1kb", 1024, false}, + {"5mb", 5 * 1024 * 1024, false}, + {"5,5mb", 5.5 * 1024 * 1024, false}, + {"8.5mb", 8.5 * 1024 * 1024, false}, + {"10,25 mb", 10.25 * 1024 * 1024, false}, + {"10gb", 10 * 1024 * 1024 * 1024, false}, + {"1KB", 1024, false}, + {"5MB", 5 * 1024 * 1024, false}, + {"12 MB", 12 * 1024 * 1024, false}, + {"10GB", 10 * 1024 * 1024 * 1024, false}, + {"123", 0, true}, + {"15TB", 0, true}, + {"invalid", 0, true}, + {"", 0, true}, + {"123 KB", 123 * 1024, false}, + {"1 MB", 1 * 1024 * 1024, false}, + } + + for _, tc := range testCases { + result, err := convertSizeUnitToByte(tc.input) + + if tc.err && err == nil { + t.Errorf("Expected an error for input %s, but got none", tc.input) + } + + if !tc.err && err != nil { + t.Errorf("Unexpected error for input %s: %v", tc.input, err) + } + + if result != tc.expected { + t.Errorf("For input %s, expected %d bytes, but got %d", tc.input, tc.expected, result) + } + } +} diff --git a/go.sum b/go.sum index 95a2f10..c23cb6f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/Trendyol/kafka-cronsumer v1.5.0 h1:MI0/ncHrlCvOV0Ro4h9avm2izsNprBw4QfabiSnzm0U= -github.com/Trendyol/kafka-cronsumer v1.5.0/go.mod h1:VpweJmKY+6dppFhzWOZDbZfxBNuJkSxB12CcuZWBNFU= github.com/Trendyol/kafka-cronsumer v1.5.1 h1:L8RLxo8zSGOfVpjtXLUqL3PsJLZdeoFcOvN1yCY/GyQ= github.com/Trendyol/kafka-cronsumer v1.5.1/go.mod h1:VpweJmKY+6dppFhzWOZDbZfxBNuJkSxB12CcuZWBNFU= github.com/Trendyol/otel-kafka-konsumer v0.0.7 h1:sT1TE2rgfsdrJWrXKz5j6dPkKJsvP+Tv0Dea4ORqJ+4= diff --git a/message.go b/message.go index 58cb114..2022eba 100644 --- a/message.go +++ b/message.go @@ -39,6 +39,27 @@ type Message struct { ErrDescription string } +func (m *Message) TotalSize() int { + return 14 + m.keySize() + m.valueSize() + m.headerSize() +} + +func (m *Message) headerSize() int { + s := 0 + for _, header := range m.Headers { + s += sizeofString(header.Key) + s += len(header.Value) + } + return s +} + +func (m *Message) keySize() int { + return sizeofBytes(m.Key) +} + +func (m *Message) valueSize() int { + return sizeofBytes(m.Value) +} + type IncomingMessage struct { kafkaMessage *kafka.Message message *Message @@ -157,3 +178,11 @@ func (m *Message) RemoveHeader(header Header) { } } } + +func sizeofBytes(b []byte) int { + return 4 + len(b) +} + +func sizeofString(s string) int { + return 2 + len(s) +} diff --git a/message_test.go b/message_test.go index 03e5012..b9dcd3f 100644 --- a/message_test.go +++ b/message_test.go @@ -98,6 +98,25 @@ func TestMessage_AddHeader(t *testing.T) { }) } +func TestMessage_Size(t *testing.T) { + // Given + m := Message{ + Headers: []kafka.Header{ + {Key: "foo", Value: []byte("fooValue")}, + }, + Value: []byte("barValue"), + Key: []byte("bar"), + } + + // When + s := m.TotalSize() + + // Then + if s != 46 { + t.Fatalf("Total message size must be equal to 46") + } +} + func TestMessage_RemoveHeader(t *testing.T) { // Given m := Message{ diff --git a/tls_test.go b/tls_test.go index 71860dc..e3351e0 100644 --- a/tls_test.go +++ b/tls_test.go @@ -26,7 +26,6 @@ func TestTLSConfig_TLSConfig(t *testing.T) { // When _, err = tlsCfg.TLSConfig() - // Then if err != nil { t.Fatalf("Error when settings tls certificates %s", err.Error())