diff --git a/Makefile b/Makefile
index fecda8c..af2f0d9 100644
--- a/Makefile
+++ b/Makefile
@@ -39,3 +39,6 @@ integration-test:
 ## run-act: act for running github actions on your local machine
 run-act:
 	act -j test --container-architecture linux/arm64
+
+run-benchmarks:
+	go test -run none -bench . -benchtime=5s
\ No newline at end of file
diff --git a/batch_consumer.go b/batch_consumer.go
index c338b55..c82aeca 100644
--- a/batch_consumer.go
+++ b/batch_consumer.go
@@ -90,24 +90,42 @@ func (b *batchConsumer) Consume() {
 func (b *batchConsumer) startBatch() {
 	defer b.wg.Done()
 
-	ticker := time.NewTicker(b.messageGroupDuration)
-	defer ticker.Stop()
+	flushTimer := time.NewTimer(b.messageGroupDuration)
+	defer flushTimer.Stop()
 
 	maximumMessageLimit := b.messageGroupLimit * b.concurrency
 	maximumMessageByteSizeLimit := b.messageGroupByteSizeLimit * b.concurrency
+
 	messages := make([]*Message, 0, maximumMessageLimit)
 	commitMessages := make([]kafka.Message, 0, maximumMessageLimit)
 	messageByteSize := 0
+
+	flushBatch := func(reason string) {
+		if len(messages) == 0 {
+			return
+		}
+
+		b.consume(&messages, &commitMessages, &messageByteSize)
+
+		b.logger.Debugf("[batchConsumer] Flushed batch, reason=%s", reason)
+
+		// After flushing, we always reset the timer
+		// But first we need to stop it and drain any event that might be pending
+		if !flushTimer.Stop() {
+			drainTimer(flushTimer)
+		}
+
+		// Now reset to start a new "rolling" interval
+		flushTimer.Reset(b.messageGroupDuration)
+	}
+
 	for {
 		select {
-		case <-ticker.C:
-			if len(messages) == 0 {
-				continue
-			}
-
-			b.consume(&messages, &commitMessages, &messageByteSize)
+		case <-flushTimer.C:
+			flushBatch("time-based (rolling timer)")
 		case msg, ok := <-b.incomingMessageStream:
 			if !ok {
+				flushBatch("channel-closed (final flush)")
 				close(b.batchConsumingStream)
 				close(b.messageProcessedStream)
 				return
@@ -117,7 +135,7 @@ func (b *batchConsumer) startBatch() {
 
 			// Check if there is an enough byte in batch, if not flush it.
 			if maximumMessageByteSizeLimit != 0 && messageByteSize+msgSize > maximumMessageByteSizeLimit {
-				b.consume(&messages, &commitMessages, &messageByteSize)
+				flushBatch("byte-size-limit")
 			}
 
 			messages = append(messages, msg.message)
@@ -126,7 +144,14 @@ func (b *batchConsumer) startBatch() {
 
 			// Check if there is an enough size in batch, if not flush it.
 			if len(messages) == maximumMessageLimit {
-				b.consume(&messages, &commitMessages, &messageByteSize)
+				flushBatch("message-count-limit")
+			} else {
+				// Rolling timer logic: reset the timer each time we get a new message
+				// Because we "stop" it, we might need to drain the channel
+				if !flushTimer.Stop() {
+					drainTimer(flushTimer)
+				}
+				flushTimer.Reset(b.messageGroupDuration)
 			}
 		}
 	}
@@ -144,33 +169,36 @@ func (b *batchConsumer) setupConcurrentWorkers() {
 	}
 }
 
-func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message {
+func chunkMessagesOptimized(allMessages []*Message, chunkSize int, chunkByteSize int) [][]*Message {
+	if chunkSize <= 0 {
+		panic("chunkSize must be greater than 0")
+	}
+
 	var chunks [][]*Message
+	totalMessages := len(allMessages)
+	estimatedChunks := (totalMessages + chunkSize - 1) / chunkSize
+	chunks = make([][]*Message, 0, estimatedChunks)
 
-	allMessageList := *allMessages
 	var currentChunk []*Message
-	currentChunkSize := 0
+	currentChunk = make([]*Message, 0, chunkSize)
 	currentChunkBytes := 0
 
-	for _, message := range allMessageList {
+	for _, message := range allMessages {
 		messageByteSize := len(message.Value)
 
 		// Check if adding this message would exceed either the chunk size or the byte size
-		if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) {
-			// Avoid too low chunkByteSize
+		if len(currentChunk) >= chunkSize || (chunkByteSize > 0 && currentChunkBytes+messageByteSize > chunkByteSize) {
 			if len(currentChunk) == 0 {
-				panic("invalid chunk byte size, please increase it")
+				panic(fmt.Sprintf("invalid chunk byte size (messageGroupByteSizeLimit) %d, "+
+					"message byte size %d is bigger, increase the chunk byte size limit", chunkByteSize, messageByteSize))
 			}
-			// If it does, finalize the current chunk and start a new one
 			chunks = append(chunks, currentChunk)
-			currentChunk = []*Message{}
-			currentChunkSize = 0
+			currentChunk = make([]*Message, 0, chunkSize)
 			currentChunkBytes = 0
 		}
 
 		// Add the message to the current chunk
 		currentChunk = append(currentChunk, message)
-		currentChunkSize++
 		currentChunkBytes += messageByteSize
 	}
 
@@ -183,11 +211,11 @@ func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message {
 }
 
 func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) {
-	chunks := chunkMessages(allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit)
+	chunks := chunkMessagesOptimized(*allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit)
 
 	if b.preBatchFn != nil {
 		preBatchResult := b.preBatchFn(*allMessages)
-		chunks = chunkMessages(&preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit)
+		chunks = chunkMessagesOptimized(preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit)
 	}
 
 	// Send the messages to process
diff --git a/batch_consumer_test.go b/batch_consumer_test.go
index 1f32005..efe3279 100644
--- a/batch_consumer_test.go
+++ b/batch_consumer_test.go
@@ -4,8 +4,8 @@ import (
 	"context"
 	"errors"
 	"reflect"
-	"strconv"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -18,7 +18,7 @@ import (
 func Test_batchConsumer_startBatch(t *testing.T) {
 	// Given
-	var numberOfBatch int
+	var numberOfBatch atomic.Int64
 
 	mc := mockReader{}
 
 	bc := batchConsumer{
@@ -32,10 +32,11 @@ func Test_batchConsumer_startBatch(t *testing.T) {
 			messageGroupDuration: 500 * time.Millisecond,
 			r:                    &mc,
 			concurrency:          1,
+			logger:               NewZapLogger(LogLevelDebug),
 		},
 		messageGroupLimit: 3,
 		consumeFn: func(_ []*Message) error {
-			numberOfBatch++
+			numberOfBatch.Add(1)
 			return nil
 		},
 	}
@@ -75,7 +76,7 @@ func Test_batchConsumer_startBatch(t *testing.T) {
 	bc.startBatch()
 
 	// Then
-	if numberOfBatch != 2 {
+	if numberOfBatch.Load() != 2 {
 		t.Fatalf("Number of batch group must equal to 2")
 	}
 	if bc.metric.TotalProcessedMessagesCounter != 4 {
@@ -99,6 +100,7 @@ func Test_batchConsumer_startBatch_with_preBatch(t *testing.T) {
 			messageGroupDuration: 20 * time.Second,
 			r:                    &mc,
 			concurrency:          1,
+			logger:               NewZapLogger(LogLevelDebug),
 		},
 		messageGroupLimit: 2,
 		consumeFn: func(_ []*Message) error {
@@ -322,79 +324,6 @@ func Test_batchConsumer_process(t *testing.T) {
 	})
 }
 
-func Test_batchConsumer_chunk(t *testing.T) {
-	tests := []struct {
-		allMessages   []*Message
-		expected      [][]*Message
-		chunkSize     int
-		chunkByteSize int
-	}{
-		{
-			allMessages:   createMessages(0, 9),
-			chunkSize:     3,
-			chunkByteSize: 10000,
-			expected: [][]*Message{
-				createMessages(0, 3),
-				createMessages(3, 6),
-				createMessages(6, 9),
-			},
-		},
-		{
-			allMessages:   []*Message{},
-			chunkSize:     3,
-			chunkByteSize: 10000,
-			expected:      [][]*Message{},
-		},
-		{
-			allMessages:   createMessages(0, 1),
-			chunkSize:     3,
-			chunkByteSize: 10000,
-			expected: [][]*Message{
-				createMessages(0, 1),
-			},
-		},
-		{
-			allMessages:   createMessages(0, 8),
-			chunkSize:     3,
-			chunkByteSize: 10000,
-			expected: [][]*Message{
-				createMessages(0, 3),
-				createMessages(3, 6),
-				createMessages(6, 8),
-			},
-		},
-		{
-			allMessages:   createMessages(0, 3),
-			chunkSize:     3,
-			chunkByteSize: 10000,
-			expected: [][]*Message{
-				createMessages(0, 3),
-			},
-		},
-
-		{
-			allMessages:   createMessages(0, 3),
-			chunkSize:     100,
-			chunkByteSize: 4,
-			expected: [][]*Message{
-				createMessages(0, 1),
-				createMessages(1, 2),
-				createMessages(2, 3),
-			},
-		},
-	}
-
-	for i, tc := range tests {
-		t.Run(strconv.Itoa(i), func(t *testing.T) {
-			chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize, tc.chunkByteSize)
-
-			if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) {
-				t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages)
-			}
-		})
-	}
-}
-
 func Test_batchConsumer_Pause(t *testing.T) {
 	// Given
 	ctx, cancelFn := context.WithCancel(context.Background())
@@ -479,6 +408,187 @@ func Test_batchConsumer_runKonsumerFn(t *testing.T) {
 	})
 }
 
+func Test_batchConsumer_chunk(t *testing.T) {
+	type testCase struct {
+		name          string
+		allMessages   []*Message
+		chunkSize     int
+		chunkByteSize int
+		expected      [][]*Message
+		shouldPanic   bool
+	}
+
+	tests := []testCase{
+		{
+			name:          "Should_Return_3_Chunks_For_9_Messages",
+			allMessages:   createMessages(0, 9),
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected: [][]*Message{
+				createMessages(0, 3),
+				createMessages(3, 6),
+				createMessages(6, 9),
+			},
+			shouldPanic: false,
+		},
+		{
+			name:          "Should_Return_Empty_Slice_When_Input_Is_Empty",
+			allMessages:   []*Message{},
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected:      [][]*Message{},
+			shouldPanic:   false,
+		},
+		{
+			name:          "Should_Return_Single_Chunk_When_Single_Message",
+			allMessages:   createMessages(0, 1),
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected: [][]*Message{
+				createMessages(0, 1),
+			},
+			shouldPanic: false,
+		},
+		{
+			name:          "Should_Split_Into_Multiple_Chunks_With_Incomplete_Final_Chunk",
+			allMessages:   createMessages(0, 8),
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected: [][]*Message{
+				createMessages(0, 3),
+				createMessages(3, 6),
+				createMessages(6, 8),
+			},
+			shouldPanic: false,
+		},
+		{
+			name:          "Should_Return_Exact_Chunk_Size_Forms_Single_Chunk",
"Should_Return_Exact_Chunk_Size_Forms_Single_Chunk", + allMessages: createMessages(0, 3), + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{ + createMessages(0, 3), + }, + shouldPanic: false, + }, + { + name: "Should_Forces_Single_Message_Per_Chunk_When_Small_chunkByteSize_Is_Given", + allMessages: createMessages(0, 3), + chunkSize: 100, + chunkByteSize: 4, // Each message has Value size 4 + expected: [][]*Message{ + createMessages(0, 1), + createMessages(1, 2), + createMessages(2, 3), + }, + shouldPanic: false, + }, + { + name: "Should_Ignore_Byte_Size_When_chunkByteSize=0", + allMessages: createMessages(0, 5), + chunkSize: 2, + chunkByteSize: 0, + expected: [][]*Message{ + createMessages(0, 2), + createMessages(2, 4), + createMessages(4, 5), + }, + shouldPanic: false, + }, + { + name: "Should_Panic_When_chunkByteSize_Less_Than_Message_Size", + allMessages: createMessages(0, 1), + chunkSize: 2, + chunkByteSize: 3, // Message size is 4 + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Panic_When_chunkSize=0", + allMessages: createMessages(0, 1), + chunkSize: 0, + chunkByteSize: 10000, + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Panic_When_Negative_chunkSize", + allMessages: createMessages(0, 1), + chunkSize: -1, + chunkByteSize: 10000, + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Return_Exact_chunkByteSize", + allMessages: createMessages(0, 4), + chunkSize: 2, + chunkByteSize: 8, // Each message has Value size 4, total 16 bytes + expected: [][]*Message{ + createMessages(0, 2), + createMessages(2, 4), + }, + shouldPanic: false, + }, + { + name: "Should_Handle_Varying_Message_Byte_Sizes", + allMessages: []*Message{ + {Partition: 0, Value: []byte("a")}, // 1 byte + {Partition: 1, Value: []byte("ab")}, // 2 bytes + {Partition: 2, Value: []byte("abc")}, // 3 bytes + {Partition: 3, Value: []byte("abcd")}, // 4 bytes + }, + chunkSize: 3, + chunkByteSize: 6, + expected: [][]*Message{ + { + {Partition: 0, Value: []byte("a")}, + {Partition: 1, Value: []byte("ab")}, + {Partition: 2, Value: []byte("abc")}, + }, + { + {Partition: 3, Value: []byte("abcd")}, + }, + }, + shouldPanic: false, + }, + } + + for _, tc := range tests { + tc := tc // Capture range variable + t.Run(tc.name, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for test case '%s', but did not panic", tc.name) + } + }() + } + + chunkedMessages := chunkMessagesOptimized(tc.allMessages, tc.chunkSize, tc.chunkByteSize) + + if !tc.shouldPanic { + // Verify the number of chunks + if len(chunkedMessages) != len(tc.expected) { + t.Errorf("Test case '%s': expected %d chunks, got %d", tc.name, len(tc.expected), len(chunkedMessages)) + } + + // Verify each chunk's content + for i, expectedChunk := range tc.expected { + if i >= len(chunkedMessages) { + t.Errorf("Test case '%s': missing chunk %d", tc.name, i) + continue + } + actualChunk := chunkedMessages[i] + if !messagesEqual(actualChunk, expectedChunk) { + t.Errorf("Test case '%s': expected chunk %d to be %v, but got %v", tc.name, i, expectedChunk, actualChunk) + } + } + } + }) + } +} + func createMessages(partitionStart int, partitionEnd int) []*Message { messages := make([]*Message, 0) for i := partitionStart; i < partitionEnd; i++ { @@ -490,6 +600,21 @@ func createMessages(partitionStart int, partitionEnd int) []*Message { return messages } +func messagesEqual(a, b []*Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if 
+			return false
+		}
+		if !reflect.DeepEqual(a[i].Value, b[i].Value) {
+			return false
+		}
+	}
+	return true
+}
+
 type mockCronsumer struct {
 	wantErr           bool
 	retryBehaviorOpen bool
diff --git a/chunkMessages_benchmark_test.go b/chunkMessages_benchmark_test.go
new file mode 100644
index 0000000..b977c17
--- /dev/null
+++ b/chunkMessages_benchmark_test.go
@@ -0,0 +1,85 @@
+package kafka
+
+import (
+	"math/rand"
+	"testing"
+	"time"
+)
+
+func BenchmarkChunkMessages(b *testing.B) {
+	b.ReportAllocs()
+	rand.New(rand.NewSource(time.Now().UnixNano()))
+	messages := generateMessages(10000, 100) // 10,000 messages, each 100 bytes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		// Create a copy of the messages slice to prevent compiler optimizations
+		msgsCopy := make([]*Message, len(messages))
+		copy(msgsCopy, messages)
+		oldChunkMessages(&msgsCopy, 100, 10000)
+	}
+}
+
+func BenchmarkChunkMessagesOptimized(b *testing.B) {
+	b.ReportAllocs()
+	rand.New(rand.NewSource(time.Now().UnixNano()))
+	messages := generateMessages(10000, 100) // 10,000 messages, each 100 bytes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		// Create a copy of the messages slice to prevent compiler optimizations
+		msgsCopy := make([]*Message, len(messages))
+		copy(msgsCopy, messages)
+		chunkMessagesOptimized(msgsCopy, 100, 10000)
+	}
+}
+
+func oldChunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message {
+	var chunks [][]*Message
+
+	allMessageList := *allMessages
+	var currentChunk []*Message
+	currentChunkSize := 0
+	currentChunkBytes := 0
+
+	for _, message := range allMessageList {
+		messageByteSize := len(message.Value)
+
+		// Check if adding this message would exceed either the chunk size or the byte size
+		if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) {
+			// Avoid too low chunkByteSize
+			if len(currentChunk) == 0 {
+				panic("invalid chunk byte size, please increase it")
+			}
+			// If it does, finalize the current chunk and start a new one
+			chunks = append(chunks, currentChunk)
+			currentChunk = []*Message{}
+			currentChunkSize = 0
+			currentChunkBytes = 0
+		}
+
+		// Add the message to the current chunk
+		currentChunk = append(currentChunk, message)
+		currentChunkSize++
+		currentChunkBytes += messageByteSize
+	}
+
+	// Add the last chunk if it has any messages
+	if len(currentChunk) > 0 {
+		chunks = append(chunks, currentChunk)
+	}
+
+	return chunks
+}
+
+func generateMessages(count int, valueSize int) []*Message {
+	messages := make([]*Message, count)
+	for i := 0; i < count; i++ {
+		b := make([]byte, valueSize)
+		for j := range b {
+			b[j] = byte(rand.Intn(26) + 97)
+		}
+		messages[i] = &Message{Value: b}
+	}
+	return messages
+}
diff --git a/consumer.go b/consumer.go
index d3ee70a..257c5b8 100644
--- a/consumer.go
+++ b/consumer.go
@@ -68,22 +68,38 @@ func (c *consumer) Consume() {
 func (c *consumer) startBatch() {
 	defer c.wg.Done()
 
-	ticker := time.NewTicker(c.messageGroupDuration)
-	defer ticker.Stop()
+	flushTimer := time.NewTimer(c.messageGroupDuration)
+	defer flushTimer.Stop()
 
 	messages := make([]*Message, 0, c.concurrency)
 	commitMessages := make([]kafka.Message, 0, c.concurrency)
 
+	flushBatch := func(reason string) {
+		if len(messages) == 0 {
+			return
+		}
+
+		c.consume(&messages, &commitMessages)
+
+		c.logger.Debugf("[singleConsumer] Flushed batch, reason=%s", reason)
+
+		// After flushing, we always reset the timer
+		// But first we need to stop it and drain any event that might be pending
+		if !flushTimer.Stop() {
+			drainTimer(flushTimer)
+		}
+
+		// Now reset to start a new "rolling" interval
+		flushTimer.Reset(c.messageGroupDuration)
+	}
+
 	for {
 		select {
-		case <-ticker.C:
-			if len(messages) == 0 {
-				continue
-			}
-
-			c.consume(&messages, &commitMessages)
+		case <-flushTimer.C:
+			flushBatch("time-based (rolling timer)")
 		case msg, ok := <-c.incomingMessageStream:
 			if !ok {
+				flushBatch("channel-closed (final flush)")
 				close(c.singleConsumingStream)
 				close(c.messageProcessedStream)
 				return
@@ -93,7 +109,14 @@ func (c *consumer) startBatch() {
 			commitMessages = append(commitMessages, *msg.kafkaMessage)
 
 			if len(messages) == c.concurrency {
-				c.consume(&messages, &commitMessages)
+				flushBatch("message-count-limit")
+			} else {
+				// Rolling timer logic: reset the timer each time we get a new message
+				// Because we "stop" it, we might need to drain the channel
+				if !flushTimer.Stop() {
+					drainTimer(flushTimer)
+				}
+				flushTimer.Reset(c.messageGroupDuration)
 			}
 		}
 	}
diff --git a/consumer_base.go b/consumer_base.go
index 083e2df..9ff590d 100644
--- a/consumer_base.go
+++ b/consumer_base.go
@@ -294,3 +294,10 @@ func (c *base) Stop() error {
 
 	return err
 }
+
+func drainTimer(t *time.Timer) {
+	select {
+	case <-t.C:
+	default:
+	}
+}
diff --git a/consumer_base_test.go b/consumer_base_test.go
index 40e0000..13589ba 100644
--- a/consumer_base_test.go
+++ b/consumer_base_test.go
@@ -59,7 +59,6 @@ func Test_base_startConsume(t *testing.T) {
 			t.Error(diff)
 		}
 	})
-
 	t.Run("Skip_Incoming_Messages_When_SkipMessageByHeaderFn_Is_Applied", func(t *testing.T) {
 		// Given
 		mc := mockReader{}
@@ -220,6 +219,31 @@ func Test_base_Resume(t *testing.T) {
 	})
 }
 
+func Test_drainTimer(t *testing.T) {
+	// Test case 1: Timer expires before calling drainTimer
+	t1 := time.NewTimer(10 * time.Millisecond)
+	time.Sleep(20 * time.Millisecond) // Ensure the timer has expired
+	drainTimer(t1)
+	select {
+	case <-t1.C:
+		t.Error("Timer channel should be drained but is not.")
+	default:
+		// Success, the channel is drained
+	}
+
+	// Clear timer state for test case 2
+	t1.Reset(50 * time.Millisecond)
+
+	// Test case 2: Timer is still active when calling drainTimer
+	drainTimer(t1)
+	select {
+	case <-t1.C:
+		// Timer should still expire normally
+	case <-time.After(100 * time.Millisecond):
+		t.Error("Timer did not expire as expected.")
+	}
+}
+
 type mockReader struct {
 	wantErr bool
 }
diff --git a/consumer_test.go b/consumer_test.go
index 61623c4..87367ef 100644
--- a/consumer_test.go
+++ b/consumer_test.go
@@ -4,9 +4,77 @@ import (
 	"context"
 	"errors"
 	"sync"
+	"sync/atomic"
 	"testing"
+	"time"
+
+	"github.com/segmentio/kafka-go"
 )
 
+func Test_consumer_startBatch(t *testing.T) {
+	// Given
+	var numberOfBatch atomic.Int64
+
+	mc := mockReader{}
+	c := consumer{
+		base: &base{
+			incomingMessageStream:  make(chan *IncomingMessage, 1),
+			singleConsumingStream:  make(chan *Message, 1),
+			messageProcessedStream: make(chan struct{}, 1),
+			metric:                 &ConsumerMetric{},
+			wg:                     sync.WaitGroup{},
+			messageGroupDuration:   500 * time.Millisecond,
+			r:                      &mc,
+			concurrency:            1,
+			logger:                 NewZapLogger(LogLevelDebug),
+		},
+		consumeFn: func(*Message) error {
+			numberOfBatch.Add(1)
+			return nil
+		},
+	}
+
+	go func() {
+		// Send two messages back to back; concurrency is 1, so each one flushes its own batch
+		c.base.incomingMessageStream <- &IncomingMessage{
+			kafkaMessage: &kafka.Message{},
+			message:      &Message{},
+		}
+		c.base.incomingMessageStream <- &IncomingMessage{
+			kafkaMessage: &kafka.Message{},
+			message:      &Message{},
+		}
+
+		time.Sleep(1 * time.Second)
+
+		// messageGroupDuration has elapsed; send a third message
+		c.base.incomingMessageStream <- &IncomingMessage{
+			kafkaMessage: &kafka.Message{},
+			message:      &Message{},
+		}
+
+		time.Sleep(1 * time.Second)
+
+		// Return from startBatch
+		close(c.base.incomingMessageStream)
+	}()
+
+	c.base.wg.Add(1 + c.base.concurrency)
+
+	// When
+	c.setupConcurrentWorkers()
+	c.startBatch()
+
+	// Then
+	if numberOfBatch.Load() != 3 {
+		t.Fatalf("Number of batch group must equal to 3")
+	}
+
+	if c.metric.TotalProcessedMessagesCounter != 3 {
+		t.Fatalf("Total Processed Message Counter must equal to 3")
+	}
+}
+
 func Test_consumer_process(t *testing.T) {
 	t.Run("When_Processing_Is_Successful", func(t *testing.T) {
 		// Given
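Note on the pattern used above: both startBatch implementations replace the fixed ticker with a rolling time.Timer that is stopped, drained via drainTimer, and reset after every flush and after every buffered message, so the flush interval is measured from the last activity rather than from a fixed tick. The sketch below is a minimal, self-contained illustration of that stop/drain/reset idiom; the rollingFlush function, its parameters, and the durations are illustrative assumptions and are not part of this patch. The new chunking benchmarks themselves can be run with the added Makefile target, make run-benchmarks, and compared via ns/op and allocs/op.

package main

import (
	"fmt"
	"time"
)

// drainTimer empties the timer's channel if an expiry is already pending
// (same shape as the helper added to consumer_base.go).
func drainTimer(t *time.Timer) {
	select {
	case <-t.C:
	default:
	}
}

// rollingFlush batches ints and flushes either when the rolling timer fires
// or when the batch reaches batchSize. The timer is reset after every flush
// and after every buffered item, so the interval counts from the last activity.
func rollingFlush(items <-chan int, batchSize int, interval time.Duration) {
	flushTimer := time.NewTimer(interval)
	defer flushTimer.Stop()

	batch := make([]int, 0, batchSize)

	flush := func(reason string) {
		if len(batch) == 0 {
			return
		}
		fmt.Printf("flushed %d items, reason=%s\n", len(batch), reason)
		batch = batch[:0]

		// Stop and drain before Reset so a stale expiry cannot fire immediately.
		if !flushTimer.Stop() {
			drainTimer(flushTimer)
		}
		flushTimer.Reset(interval)
	}

	for {
		select {
		case <-flushTimer.C:
			flush("time-based")
		case item, ok := <-items:
			if !ok {
				flush("channel-closed")
				return
			}
			batch = append(batch, item)
			if len(batch) == batchSize {
				flush("count-limit")
			} else {
				// Rolling behaviour: a new item postpones the next time-based flush.
				if !flushTimer.Stop() {
					drainTimer(flushTimer)
				}
				flushTimer.Reset(interval)
			}
		}
	}
}

func main() {
	items := make(chan int)
	go func() {
		for i := 0; i < 7; i++ {
			items <- i
			time.Sleep(30 * time.Millisecond)
		}
		close(items)
	}()
	rollingFlush(items, 3, 100*time.Millisecond)
}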