diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index e7e17aba6053..a9bb29ce73f0 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -37,6 +37,7 @@
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.RecordBatchTooLargeException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.header.Header;
@@ -667,6 +668,13 @@ boolean flushInProgress() {
         return flushesInProgress.get() > 0;
     }
 
+    /**
+     * This method should be used only for testing.
+     */
+    IncompleteBatches incompleteBatches() {
+        return incomplete;
+    }
+
     /* Visible for testing */
     Map<TopicPartition, Deque<ProducerBatch>> batches() {
         return Collections.unmodifiableMap(batches);
@@ -691,17 +699,29 @@ private boolean appendsInProgress() {
      */
     public void awaitFlushCompletion(long timeoutMs) throws InterruptedException {
         try {
+            boolean retry;
             Long expireMs = System.currentTimeMillis() + timeoutMs;
-            for (ProducerBatch batch : this.incomplete.copyAll()) {
-                Long currentMs = System.currentTimeMillis();
-                if (currentMs > expireMs) {
-                    throw new TimeoutException("Failed to flush accumulated records within" + timeoutMs + "milliseconds.");
-                }
-                boolean completed = batch.produceFuture.await(Math.max(expireMs - currentMs, 0), TimeUnit.MILLISECONDS);
-                if (!completed) {
-                    throw new TimeoutException("Failed to flush accumulated records within" + timeoutMs + "milliseconds.");
+            do {
+                retry = false;
+                for (ProducerBatch batch : this.incomplete.copyAll()) {
+                    Long currentMs = System.currentTimeMillis();
+                    if (currentMs > expireMs) {
+                        throw new TimeoutException("Failed to flush accumulated records within " + timeoutMs + " milliseconds.");
+                    }
+                    boolean completed = batch.produceFuture.await(Math.max(expireMs - currentMs, 0), TimeUnit.MILLISECONDS);
+                    if (!completed) {
+                        throw new TimeoutException("Failed to flush accumulated records within " + timeoutMs + " milliseconds.");
+                    }
+                    // If the produceFuture failed with RecordBatchTooLargeException, the Sender thread
+                    // split the batch into smaller batches and re-enqueued them into the RecordAccumulator.
+                    // Retry the outer loop so that the split batches are awaited as well. Note that this
+                    // may wait on more records than strictly necessary, because each retry also picks up
+                    // any records newly appended via the KafkaProducer.send() API.
+                    if (batch.produceFuture.error() instanceof RecordBatchTooLargeException) {
+                        retry = true;
+                    }
                 }
-            }
+            } while (retry);
         } finally {
             this.flushesInProgress.decrementAndGet();
         }
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index 12cf53bb873c..a521b342d5ee 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.producer.Callback;
@@ -44,6 +46,7 @@
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
@@ -417,6 +420,54 @@ public void run() {
         t.start();
     }
 
+    @Test
+    public void testSplitAwaitFlushComplete() throws Exception {
+        RecordAccumulator accum = createTestRecordAccumulator(1024, 10 * 1024, CompressionType.GZIP, 10);
+
+        // Append two 256-byte records to the same partition so they land in a single big batch.
+        byte[] value = new byte[256];
+        // This batch will later be split, as if it had failed with RecordBatchTooLargeException.
+        accum.append(new TopicPartition(topic, 0), 0L, null, value, null, null, maxBlockTimeMs);
+        accum.append(new TopicPartition(topic, 0), 0L, null, value, null, null, maxBlockTimeMs);
+
+        CountDownLatch flushInProgress = new CountDownLatch(1);
+        Iterator<ProducerBatch> incompleteBatches = accum.incompleteBatches().copyAll().iterator();
+
+        // Assert that there is exactly one incomplete batch
+        Assert.assertTrue(incompleteBatches.hasNext());
+        ProducerBatch producerBatch = incompleteBatches.next();
+        Assert.assertFalse(incompleteBatches.hasNext());
+
+        AtomicBoolean timedOut = new AtomicBoolean(false);
+        Thread thread = new Thread(() -> {
+            Assert.assertTrue(accum.hasIncomplete());
+            accum.beginFlush();
+            Assert.assertTrue(accum.flushInProgress());
+            try {
+                flushInProgress.countDown();
+                accum.awaitFlushCompletion(2000);
+            } catch (TimeoutException timeoutException) {
+                // Record the timeout in the shared flag;
+                // this is the only valid exit path for this thread.
+                timedOut.set(true);
+            } catch (InterruptedException e) {
+            }
+        });
+        thread.start();
+        flushInProgress.await();
+        // Wait 100ms to make sure that the flush is actually in progress
+        Thread.sleep(100);
+
+        // Split the big batch and re-enqueue the resulting smaller batches
+        accum.splitAndReenqueue(producerBatch);
+        accum.deallocate(producerBatch);
+
+        thread.join();
+        // The thread must have hit the TimeoutException: the re-enqueued child batches
+        // never complete, so awaitFlushCompletion() waits the full 2 seconds and times out.
+        Assert.assertTrue("The thread should have timed out", timedOut.get());
+    }
+
     @Test
     public void testAwaitFlushComplete() throws Exception {
         RecordAccumulator accum = createTestRecordAccumulator(
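
For context, here is a minimal sketch (not part of the patch) of the producer-side scenario the retry loop addresses. When a compressed batch is rejected by the broker with MESSAGE_TOO_LARGE, the Sender splits it, fails the parent batch's future with RecordBatchTooLargeException, and re-enqueues the smaller child batches; KafkaProducer.flush() funnels into RecordAccumulator.awaitFlushCompletion(), so before this patch it could return while those children were still in flight. The broker address, topic name, class name, and sizes below are illustrative assumptions, not values taken from the patch:

import java.util.Properties;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.ByteArraySerializer;

public class FlushAfterSplitSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");  // assumed broker address
        props.put("key.serializer", ByteArraySerializer.class.getName());
        props.put("value.serializer", ByteArraySerializer.class.getName());
        props.put("batch.size", 64 * 1024);     // large producer batches
        props.put("compression.type", "gzip");  // estimated compressed size can undershoot
        try (KafkaProducer<byte[], byte[]> producer = new KafkaProducer<>(props)) {
            // Enough records to fill several batches; if the broker's max.message.bytes is
            // smaller than a compressed batch, the Sender will split and re-enqueue it.
            for (int i = 0; i < 1000; i++)
                producer.send(new ProducerRecord<>("my-topic", new byte[1024]));
            // With the retry loop in awaitFlushCompletion(), a parent batch that failed with
            // RecordBatchTooLargeException triggers a re-scan of the incomplete batches, so
            // flush() now also waits for the re-enqueued child batches.
            producer.flush();
        }
    }
}

One trade-off, noted in the patch comment: each retry re-snapshots this.incomplete, so records appended concurrently via send() are awaited as well, and flush() can wait on more data than was buffered at the time it was called.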