diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategySpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategySpec.scala
new file mode 100644
index 000000000..5d612de26
--- /dev/null
+++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategySpec.scala
@@ -0,0 +1,87 @@
+package zio.kafka.consumer.fetch
+
+import org.apache.kafka.common.TopicPartition
+import zio.kafka.ZIOSpecDefaultSlf4j
+import zio.kafka.consumer.internal.PartitionStream
+import zio.test.{ assertTrue, Spec, TestEnvironment }
+import zio.{ Chunk, Scope, UIO, ZIO }
+
+object ManyPartitionsQueueSizeBasedFetchStrategySpec extends ZIOSpecDefaultSlf4j {
+
+  private val maxPartitionQueueSize = 50
+  private val fetchStrategy = ManyPartitionsQueueSizeBasedFetchStrategy(
+    maxPartitionQueueSize,
+    maxTotalQueueSize = 80
+  )
+
+  private val tp10 = new TopicPartition("topic1", 0)
+  private val tp11 = new TopicPartition("topic1", 1)
+  private val tp20 = new TopicPartition("topic2", 0)
+  private val tp21 = new TopicPartition("topic2", 1)
+  private val tp22 = new TopicPartition("topic2", 2)
+
+  override def spec: Spec[TestEnvironment with Scope, Any] =
+    suite("ManyPartitionsQueueSizeBasedFetchStrategySpec")(
+      test("stream with queue size above maxSize is paused") {
+        val streams = Chunk(newStream(tp10, currentQueueSize = 100))
+        for {
+          result <- fetchStrategy.selectPartitionsToFetch(streams)
+        } yield assertTrue(result.isEmpty)
+      },
+      test("stream with queue size below maxSize may resume when less-equal global max") {
+        val streams = Chunk(newStream(tp10, currentQueueSize = 10))
+        for {
+          result <- fetchStrategy.selectPartitionsToFetch(streams)
+        } yield assertTrue(result == Set(tp10))
+      },
+      test("all streams with queue size less-equal maxSize may resume when total is less-equal global max") {
+        val streams = Chunk(
+          newStream(tp10, currentQueueSize = maxPartitionQueueSize),
+          newStream(tp11, currentQueueSize = 10),
+          newStream(tp20, currentQueueSize = 10),
+          newStream(tp21, currentQueueSize = 10)
+        )
+        for {
+          result <- fetchStrategy.selectPartitionsToFetch(streams)
+        } yield assertTrue(result == Set(tp10, tp11, tp20, tp21))
+      },
+      test("not all streams with queue size less-equal maxSize may resume when total is less-equal global max") {
+        val streams = Chunk(
+          newStream(tp10, currentQueueSize = 40),
+          newStream(tp11, currentQueueSize = 40),
+          newStream(tp20, currentQueueSize = 40),
+          newStream(tp21, currentQueueSize = 40)
+        )
+        for {
+          result <- fetchStrategy.selectPartitionsToFetch(streams)
+        } yield assertTrue(result.size == 2)
+      },
+      test("all streams with queue size less-equal maxSize may resume eventually") {
+        val streams = Chunk(
+          newStream(tp10, currentQueueSize = 60),
+          newStream(tp11, currentQueueSize = 60),
+          newStream(tp20, currentQueueSize = 40),
+          newStream(tp21, currentQueueSize = 40),
+          newStream(tp22, currentQueueSize = 40)
+        )
+        for {
+          result1 <- fetchStrategy.selectPartitionsToFetch(streams)
+          result2 <- fetchStrategy.selectPartitionsToFetch(streams)
+          result3 <- fetchStrategy.selectPartitionsToFetch(streams)
+          result4 <- fetchStrategy.selectPartitionsToFetch(streams)
+          result5 <- fetchStrategy.selectPartitionsToFetch(streams)
+          results = Chunk(result1, result2, result3, result4, result5)
+        } yield assertTrue(
+          results.forall(_.size == 2),
+          results.forall(_.forall(_.topic() == "topic2")),
+          results.flatten.toSet.size == 3
+        )
+      }
+    )
+
+  private def newStream(topicPartition: TopicPartition, currentQueueSize: Int): PartitionStream =
+    new PartitionStream {
+      override def tp: TopicPartition   = topicPartition
+      override def queueSize: UIO[Int]  = ZIO.succeed(currentQueueSize)
+    }
+}
diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
new file mode 100644
index 000000000..b05f7a7ff
--- /dev/null
+++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
@@ -0,0 +1,62 @@
+package zio.kafka.consumer.fetch
+
+import zio.Scope
+import zio.kafka.ZIOSpecDefaultSlf4j
+import zio.kafka.consumer.fetch.PollHistory.PollHistoryImpl
+import zio.test._
+
+object PollHistorySpec extends ZIOSpecDefaultSlf4j {
+  override def spec: Spec[TestEnvironment with Scope, Any] = suite("PollHistorySpec")(
+    test("estimates poll count for very regular pattern") {
+      assertTrue(
+        (("001" * 22) + "").toPollHistory.estimatedPollCountToResume == 3,
+        (("001" * 22) + "0").toPollHistory.estimatedPollCountToResume == 2,
+        (("001" * 22) + "00").toPollHistory.estimatedPollCountToResume == 1,
+        (("00001" * 13) + "").toPollHistory.estimatedPollCountToResume == 5
+      )
+    },
+    test("estimates poll count for somewhat irregular pattern") {
+      assertTrue(
+        "000101001001010001000101001001001".toPollHistory.estimatedPollCountToResume == 3
+      )
+    },
+    test("estimates poll count only when paused for less than 16 polls") {
+      assertTrue(
+        "0".toPollHistory.estimatedPollCountToResume == 64,
+        "10000000000000000000000000000000".toPollHistory.estimatedPollCountToResume == 64,
+        ("11" * 8 + "00" * 8).toPollHistory.estimatedPollCountToResume == 64,
+        ("11" * 9 + "00" * 7).toPollHistory.estimatedPollCountToResume == 0
+      )
+    },
+    test("estimates poll count for edge cases") {
+      assertTrue(
+        "11111111111111111111111111111111".toPollHistory.estimatedPollCountToResume == 1,
+        "10000000000000001000000000000000".toPollHistory.estimatedPollCountToResume == 1,
+        "01000000000000000100000000000000".toPollHistory.estimatedPollCountToResume == 2,
+        "00100000000000000010000000000000".toPollHistory.estimatedPollCountToResume == 3,
+        "00010000000000000001000000000000".toPollHistory.estimatedPollCountToResume == 4
+      )
+    },
+    test("add to history") {
+      assertTrue(
+        PollHistory.Empty.addPollHistory(true).asBitString == "1",
+        "101010".toPollHistory.addPollHistory(true).asBitString == "1010101",
+        PollHistory.Empty.addPollHistory(false).asBitString == "0",
+        "1".toPollHistory.addPollHistory(false).asBitString == "10",
+        "101010".toPollHistory.addPollHistory(false).asBitString == "1010100",
+        // Adding resume after a resume is not recorded:
+        "1".toPollHistory.addPollHistory(true).asBitString == "1",
+        "10101".toPollHistory.addPollHistory(true).asBitString == "10101"
+      )
+    }
+  )
+
+  private implicit class RichPollHistory(private val ph: PollHistory) extends AnyVal {
+    def asBitString: String =
+      ph.asInstanceOf[PollHistoryImpl].resumeBits.toBinaryString
+  }
+
+  private implicit class PollHistoryOps(private val s: String) extends AnyVal {
+    def toPollHistory: PollHistory = new PollHistoryImpl(java.lang.Long.parseUnsignedLong(s.takeRight(64), 2))
+  }
+}
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
new file mode 100644
index 000000000..fc07653d1
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
@@ -0,0 +1,51 @@
+package zio.kafka.consumer.fetch
+
+import org.apache.kafka.common.TopicPartition
+import zio.{ Chunk, ZIO }
+import zio.kafka.consumer.internal.PartitionStream
+
+import scala.collection.mutable
+
+/**
+ * A fetch strategy that allows a stream to fetch data when its queue size is at or below `maxPartitionQueueSize`, as
+ * long as the total queue size is at or below `maxTotalQueueSize`. This strategy is suitable when
+ * [[QueueSizeBasedFetchStrategy]] requires too much heap space, particularly when a lot of partitions are being
+ * consumed.
+ *
+ * @param maxPartitionQueueSize
+ *   Maximum number of records to be buffered per partition. This buffer improves throughput and supports varying
+ *   downstream message processing times, while maintaining some backpressure. Large values effectively disable
+ *   backpressure at the cost of high memory usage; low values effectively disable prefetching in favour of low memory
+ *   consumption. The number of records that is fetched on every poll is controlled by the `max.poll.records` setting;
+ *   the number of records fetched for every partition is somewhere between 0 and `max.poll.records`.
+ *
+ *   The default value for this parameter is 2 * the default `max.poll.records` of 500, rounded to the nearest power
+ *   of 2.
+ * @param maxTotalQueueSize
+ *   Maximum number of records to be buffered over all partitions together. This can be used to limit memory usage
+ *   when consuming a large number of partitions.
+ *
+ *   The default value is 20 * the default `maxPartitionQueueSize`, allowing approximately 20 partitions to do
+ *   pre-fetching in each poll.
+ */
+final case class ManyPartitionsQueueSizeBasedFetchStrategy(
+  maxPartitionQueueSize: Int = 1024,
+  maxTotalQueueSize: Int = 20480
+) extends FetchStrategy {
+  override def selectPartitionsToFetch(
+    streams: Chunk[PartitionStream]
+  ): ZIO[Any, Nothing, Set[TopicPartition]] = {
+    // By shuffling the streams we prevent read-starvation for streams at the end of the list.
+    val shuffledStreams = scala.util.Random.shuffle(streams)
+    ZIO
+      .foldLeft(shuffledStreams)((mutable.ArrayBuilder.make[TopicPartition], maxTotalQueueSize)) {
+        case (acc @ (partitions, queueBudget), stream) =>
+          stream.queueSize.map { queueSize =>
+            if (queueSize <= maxPartitionQueueSize && queueSize <= queueBudget) {
+              (partitions += stream.tp, queueBudget - queueSize)
+            } else acc
+          }
+      }
+      .map { case (tps, _) => tps.result().toSet }
+  }
+}
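Illustrative usage (not part of the patch): a minimal sketch of opting a consumer into the new strategy, assuming the existing `ConsumerSettings#withFetchStrategy` hook that `FetchStrategy` implementations plug into; the bootstrap server and group id are placeholders.

import zio.kafka.consumer.ConsumerSettings
import zio.kafka.consumer.fetch.ManyPartitionsQueueSizeBasedFetchStrategy

// Sketch only: assumes `ConsumerSettings#withFetchStrategy`; server and group id are placeholders.
val settings: ConsumerSettings =
  ConsumerSettings(List("localhost:9092"))
    .withGroupId("example-group")
    .withFetchStrategy(
      ManyPartitionsQueueSizeBasedFetchStrategy(
        maxPartitionQueueSize = 1024, // per-partition buffer budget (the default)
        maxTotalQueueSize = 20480     // budget shared by all partitions together (the default)
      )
    )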
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
new file mode 100644
index 000000000..1d19a7177
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
@@ -0,0 +1,90 @@
+package zio.kafka.consumer.fetch
+
+import java.lang.{ Long => JavaLong }
+
+/**
+ * Keeps track of a partition's status ('resumed' or 'paused') history as it was just before each poll.
+ *
+ * The goal is to predict in how many polls the partition will be resumed.
+ *
+ * WARNING: this is an EXPERIMENTAL API and may change in an incompatible way without notice in any zio-kafka version.
+ */
+sealed trait PollHistory {
+
+  /**
+   * @return
+   *   the estimated number of polls before the partition is resumed (a positive number). When no estimate can be
+   *   made, this returns a high positive number.
+   */
+  def estimatedPollCountToResume: Int
+
+  /**
+   * Creates a new poll history by appending the given partition status as the latest poll. The history length might
+   * be limited. When the maximum length is reached, older history is discarded.
+   *
+   * @param resumed
+   *   true when this partition was 'resumed' before the poll, false when it was 'paused'
+   */
+  def addPollHistory(resumed: Boolean): PollHistory
+}
+
+object PollHistory {
+
+  /**
+   * An implementation of [[PollHistory]] that stores the poll statuses as bits in an unsigned [[Long]].
+   *
+   * Bit value 1 indicates that the partition was resumed and value 0 indicates it was paused. The most recent poll is
+   * in the least significant bit, the oldest poll is in the most significant bit.
+   */
+  // exposed only for tests
+  private[fetch] final class PollHistoryImpl(val resumeBits: Long) extends PollHistory {
+    override def estimatedPollCountToResume: Int = {
+      // This class works with 64 bits, but let's assume an 8 bit history for this example.
+      // The full history is "00100100".
+      // We are currently paused for 2 polls (the last "00").
+      // The 'before history' contains 2 resumes (in "001001", 6 bits long),
+      // so the average resume cycle is 6 / 2 = 3 polls,
+      // and the estimated wait time before the next resume is
+      // average resume cycle (3) - current pause (2) = 1 poll.
+
+      // Now consider the pattern "0100010001000100" (16 bit history).
+      // It is very regular but the estimate will be off because the oldest cycle
+      // (at the beginning of the bit string) is not complete.
+      // We compensate by removing the first cycle from the 'before history'.
+      // This also helps predicting when the stream only just started.
+
+      // When no resumes are observed in the 'before history', we cannot estimate and we return the maximum
+      // estimate (64).
+
+      // Also, when the 'before history' is too short, we cannot make a prediction and we return 64.
+      // We require that the 'before history' is at least 16 polls long.
+
+      val currentPausedCount   = JavaLong.numberOfTrailingZeros(resumeBits)
+      val firstPollCycleLength = JavaLong.numberOfLeadingZeros(resumeBits) + 1
+      val beforeHistory        = resumeBits >>> currentPausedCount
+      val resumeCount          = JavaLong.bitCount(beforeHistory) - 1
+      val beforeHistoryLength  = JavaLong.SIZE - firstPollCycleLength - currentPausedCount
+      if (resumeCount == 0 || beforeHistoryLength < 16) {
+        JavaLong.SIZE
+      } else {
+        val averageResumeCycleLength = Math.round(beforeHistoryLength / resumeCount.toDouble).toInt
+        Math.max(0, averageResumeCycleLength - currentPausedCount)
+      }
+    }
+
+    override def addPollHistory(resumed: Boolean): PollHistory =
+      // When `resumed` is true, and the previous poll was 'resumed' as well, one of 2 cases is possible:
+      // 1. we're still waiting for the data,
+      // 2. we did get data, but it was already processed and we need more.
+      //
+      // For case 1 we should not add to the history, for case 2 we should.
+      // We'll err on the conservative side and assume case 1.
+      if (resumed && ((resumeBits & 1) == 1)) {
+        this
+      } else {
+        new PollHistoryImpl(resumeBits << 1 | (if (resumed) 1 else 0))
+      }
+  }
+
+  /** An empty poll history. */
+  val Empty: PollHistory = new PollHistoryImpl(0)
+}
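For reference, the estimate for the fully regular pattern exercised in PollHistorySpec can be worked out with the same bit operations. The standalone sketch below mirrors the arithmetic of `PollHistoryImpl` rather than calling the `private[fetch]` class:

import java.lang.{ Long => JavaLong }

// Regular pattern "...001001001": a resume every third poll, currently not paused.
val resumeBits = JavaLong.parseUnsignedLong(("001" * 22).takeRight(64), 2)

val currentPausedCount   = JavaLong.numberOfTrailingZeros(resumeBits)              // 0: the last poll was 'resumed'
val firstPollCycleLength = JavaLong.numberOfLeadingZeros(resumeBits) + 1            // 1: the top bit is a resume
val beforeHistory        = resumeBits >>> currentPausedCount
val resumeCount          = JavaLong.bitCount(beforeHistory) - 1                     // 21, after dropping the first cycle's resume
val beforeHistoryLength  = JavaLong.SIZE - firstPollCycleLength - currentPausedCount // 63 polls
val estimate =
  if (resumeCount == 0 || beforeHistoryLength < 16) JavaLong.SIZE
  else Math.max(0, Math.round(beforeHistoryLength / resumeCount.toDouble).toInt - currentPausedCount)
// estimate == 3, matching the first assertion in PollHistorySpec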
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala
new file mode 100644
index 000000000..2dd8acb45
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala
@@ -0,0 +1,53 @@
+package zio.kafka.consumer.fetch
+
+import org.apache.kafka.common.TopicPartition
+import zio.kafka.consumer.internal.PartitionStream
+import zio.{ Chunk, ZIO }
+
+import scala.collection.mutable
+
+/**
+ * A fetch strategy that predicts when a stream needs more data by analyzing its history.
+ *
+ * The prediction is based on the average number of polls the stream needed to process data in the recent past. In
+ * addition, a stream can always fetch when it is out of data.
+ *
+ * This fetch strategy is suitable when processing takes at least a few polls. It is especially suitable when
+ * different streams (partitions) have different processing times, but each stream has a consistent processing time.
+ *
+ * Note: this strategy has mutable state; a separate instance is needed for each consumer.
+ *
+ * @param maxEstimatedPollCountsToFetch
+ *   The maximum number of estimated polls before the stream may fetch data. The default (and minimum) is 1, which
+ *   means that data is fetched 1 poll before it is needed. Setting this higher trades higher memory usage for a lower
+ *   chance that a stream needs to wait for data.
+ */
+final class PredictiveFetchStrategy(maxEstimatedPollCountsToFetch: Int = 1) extends FetchStrategy {
+  require(
+    maxEstimatedPollCountsToFetch >= 1,
+    s"`maxEstimatedPollCountsToFetch` must be at least 1, got $maxEstimatedPollCountsToFetch"
+  )
+  private val CleanupPollCount = 10
+  private var cleanupCountDown = CleanupPollCount
+  private val pollHistories    = mutable.Map.empty[PartitionStream, PollHistory]
+
+  override def selectPartitionsToFetch(
+    streams: Chunk[PartitionStream]
+  ): ZIO[Any, Nothing, Set[TopicPartition]] =
+    ZIO.succeed {
+      if (cleanupCountDown == 0) {
+        pollHistories --= (pollHistories.keySet.toSet -- streams)
+        cleanupCountDown = CleanupPollCount
+      } else {
+        cleanupCountDown -= 1
+      }
+    } *>
+      ZIO
+        .foldLeft(streams)(mutable.ArrayBuilder.make[TopicPartition]) { case (acc, stream) =>
+          stream.queueSize.map { queueSize =>
+            val outOfData        = queueSize == 0
+            val pollHistory      = pollHistories.getOrElseUpdate(stream, PollHistory.Empty)
+            val predictiveResume = pollHistory.estimatedPollCountToResume <= maxEstimatedPollCountsToFetch
+            pollHistories += (stream -> pollHistory.addPollHistory(outOfData))
+            if (outOfData || predictiveResume) acc += stream.tp else acc
+          }
+        }
+        .map(_.result().toSet)
+}
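Illustrative usage (not part of the patch): because `PredictiveFetchStrategy` keeps mutable per-partition histories, each consumer needs its own instance. A minimal sketch, assuming `ConsumerSettings#withFetchStrategy` and the existing `Consumer.make` constructor; the bootstrap server is a placeholder.

import zio.ZLayer
import zio.kafka.consumer.{ Consumer, ConsumerSettings }
import zio.kafka.consumer.fetch.PredictiveFetchStrategy

// Sketch only: builds a fresh strategy instance per consumer layer.
def consumerLayer(groupId: String): ZLayer[Any, Throwable, Consumer] =
  ZLayer.scoped {
    Consumer.make(
      ConsumerSettings(List("localhost:9092")) // placeholder bootstrap servers
        .withGroupId(groupId)
        .withFetchStrategy(new PredictiveFetchStrategy(maxEstimatedPollCountsToFetch = 1))
    )
  }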