diff --git a/spring-kafka-docs/src/main/antora/modules/ROOT/pages/kafka/seek.adoc b/spring-kafka-docs/src/main/antora/modules/ROOT/pages/kafka/seek.adoc index 8b97c94cb..ba4edab22 100644 --- a/spring-kafka-docs/src/main/antora/modules/ROOT/pages/kafka/seek.adoc +++ b/spring-kafka-docs/src/main/antora/modules/ROOT/pages/kafka/seek.adoc @@ -186,17 +186,18 @@ public class SeekToLastOnIdleListener extends AbstractConsumerSeekAware { * Rewind all partitions one record. */ public void rewindAllOneRecord() { - getSeekCallbacks() - .forEach((tp, callback) -> - callback.seekRelative(tp.topic(), tp.partition(), -1, true)); + getTopicsAndCallbacks() + .forEach((tp, callbacks) -> + callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true)) + ); } /** * Rewind one partition one record. */ public void rewindOnePartitionOneRecord(String topic, int partition) { - getSeekCallbackFor(new TopicPartition(topic, partition)) - .seekRelative(topic, partition, -1, true); + getSeekCallbacksFor(new TopicPartition(topic, partition)) + .forEach(callback -> callback.seekRelative(topic, partition, -1, true)); } } diff --git a/spring-kafka-docs/src/main/antora/modules/ROOT/pages/whats-new.adoc b/spring-kafka-docs/src/main/antora/modules/ROOT/pages/whats-new.adoc index 5688b7b4d..2c2dc2f25 100644 --- a/spring-kafka-docs/src/main/antora/modules/ROOT/pages/whats-new.adoc +++ b/spring-kafka-docs/src/main/antora/modules/ROOT/pages/whats-new.adoc @@ -17,6 +17,8 @@ The naming convention for DLT topics has been standardized to use the "-dlt" suf A new method, `getGroupId()`, has been added to the `ConsumerSeekCallback` interface. This method allows for more selective seek operations by targeting only the desired consumer group. +The `AbstractConsumerSeekAware` can also now register, retrieve, and remove all callbacks for each topic partition in a multi-group listener scenario without missing any. +See the new APIs (`getSeekCallbacksFor(TopicPartition topicPartition)`, `getTopicsAndCallbacks()`) for more details. For more details, see xref:kafka/seek.adoc#seek[Seek API Docs]. [[x33-new-option-ignore-empty-batch]] diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java index 093a1c356..0af8a3d27 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java @@ -1,5 +1,5 @@ /* - * Copyright 2019-2023 the original author or authors. + * Copyright 2019-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,19 @@ package org.springframework.kafka.listener; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import org.apache.kafka.common.TopicPartition; import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; /** * Manages the {@link ConsumerSeekAware.ConsumerSeekCallback} s for the listener. If the @@ -33,6 +36,7 @@ * having to keep track of the callbacks itself. * * @author Gary Russell + * @author Borahm Lee * @since 2.3 * */ @@ -40,9 +44,9 @@ public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware { private final Map callbackForThread = new ConcurrentHashMap<>(); - private final Map callbacks = new ConcurrentHashMap<>(); + private final Map> topicToCallbacks = new ConcurrentHashMap<>(); - private final Map> callbacksToTopic = new ConcurrentHashMap<>(); + private final Map> callbackToTopics = new ConcurrentHashMap<>(); @Override public void registerSeekCallback(ConsumerSeekCallback callback) { @@ -54,8 +58,8 @@ public void onPartitionsAssigned(Map assignments, Consumer ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread()); if (threadCallback != null) { assignments.keySet().forEach(tp -> { - this.callbacks.put(tp, threadCallback); - this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp); + this.topicToCallbacks.computeIfAbsent(tp, key -> new ArrayList<>()).add(threadCallback); + this.callbackToTopics.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp); }); } } @@ -63,15 +67,17 @@ public void onPartitionsAssigned(Map assignments, Consumer @Override public void onPartitionsRevoked(Collection partitions) { partitions.forEach(tp -> { - ConsumerSeekCallback removed = this.callbacks.remove(tp); - if (removed != null) { - List topics = this.callbacksToTopic.get(removed); - if (topics != null) { - topics.remove(tp); - if (topics.size() == 0) { - this.callbacksToTopic.remove(removed); + List removedCallbacks = this.topicToCallbacks.remove(tp); + if (removedCallbacks != null && !removedCallbacks.isEmpty()) { + removedCallbacks.forEach(cb -> { + List topics = this.callbackToTopics.get(cb); + if (topics != null) { + topics.remove(tp); + if (topics.isEmpty()) { + this.callbackToTopics.remove(cb); + } } - } + }); } }); } @@ -82,21 +88,55 @@ public void unregisterSeekCallback() { } /** - * Return the callback for the specified topic/partition. - * @param topicPartition the topic/partition. - * @return the callback (or null if there is no assignment). - */ + * Return the callback for the specified topic/partition. + * @param topicPartition the topic/partition. + * @return the callback (or null if there is no assignment). + * @deprecated Replaced by {@link #getSeekCallbacksFor(TopicPartition)} + */ + @Deprecated(since = "3.3", forRemoval = true) @Nullable protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition) { - return this.callbacks.get(topicPartition); + List callbacks = getSeekCallbacksFor(topicPartition); + if (CollectionUtils.isEmpty(callbacks)) { + return null; + } + return callbacks.get(0); + } + + /** + * Return the callbacks for the specified topic/partition. + * @param topicPartition the topic/partition. + * @return the callbacks (or null if there is no assignment). + * @since 3.3 + */ + @Nullable + protected List getSeekCallbacksFor(TopicPartition topicPartition) { + return this.topicToCallbacks.get(topicPartition); } /** * The map of callbacks for all currently assigned partitions. * @return the map. + * @deprecated Replaced by {@link #getTopicsAndCallbacks()} */ + @Deprecated(since = "3.3", forRemoval = true) protected Map getSeekCallbacks() { - return Collections.unmodifiableMap(this.callbacks); + Map> topicsAndCallbacks = getTopicsAndCallbacks(); + return topicsAndCallbacks.entrySet().stream() + .filter(entry -> !entry.getValue().isEmpty()) + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().get(0) + )); + } + + /** + * The map of callbacks for all currently assigned partitions. + * @return the map. + * @since 3.3 + */ + protected Map> getTopicsAndCallbacks() { + return Collections.unmodifiableMap(this.topicToCallbacks); } /** @@ -105,7 +145,7 @@ protected Map getSeekCallbacks() { * @since 2.6 */ protected Map> getCallbacksAndTopics() { - return Collections.unmodifiableMap(this.callbacksToTopic); + return Collections.unmodifiableMap(this.callbackToTopics); } /** @@ -113,7 +153,7 @@ protected Map> getCallbacksAndTopics( * @since 2.6 */ public void seekToBeginning() { - getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToBeginning(topics)); + getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToBeginning); } /** @@ -121,7 +161,7 @@ public void seekToBeginning() { * @since 2.6 */ public void seekToEnd() { - getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToEnd(topics)); + getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToEnd); } /** diff --git a/spring-kafka/src/test/java/org/springframework/kafka/annotation/EnableKafkaIntegrationTests.java b/spring-kafka/src/test/java/org/springframework/kafka/annotation/EnableKafkaIntegrationTests.java index 9b06a9bb6..6c6f70265 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/annotation/EnableKafkaIntegrationTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/annotation/EnableKafkaIntegrationTests.java @@ -180,6 +180,7 @@ * @author Nakul Mishra * @author Soby Chacko * @author Wang Zhiyang + * @author Borahm Lee */ @SpringJUnitConfig @DirtiesContext @@ -1081,7 +1082,7 @@ public void testSeekToLastOnIdle() throws InterruptedException { assertThat(this.seekOnIdleListener.latch3.await(10, TimeUnit.SECONDS)).isTrue(); this.registry.getListenerContainer("seekOnIdle").stop(); assertThat(this.seekOnIdleListener.latch4.await(10, TimeUnit.SECONDS)).isTrue(); - assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "callbacks", Map.class)).hasSize(0); + assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "topicToCallbacks", Map.class)).hasSize(0); } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -2523,11 +2524,10 @@ public void listen(String in) throws InterruptedException { if (latch1.getCount() > 0) { latch1.countDown(); if (latch1.getCount() == 0) { - ConsumerSeekCallback seekToComputeFn = getSeekCallbackFor( + List seekToComputeFunctions = getSeekCallbacksFor( new org.apache.kafka.common.TopicPartition("seekToComputeFn", 0)); - assertThat(seekToComputeFn).isNotNull(); - seekToComputeFn. - seek("seekToComputeFn", 0, current -> 0L); + assertThat(seekToComputeFunctions).isNotEmpty(); + seekToComputeFunctions.forEach(callback -> callback.seek("seekToComputeFn", 0, current -> 0L)); } } } @@ -2576,14 +2576,15 @@ public void onIdleContainer(Map as } public void rewindAllOneRecord() { - getSeekCallbacks() - .forEach((tp, callback) -> - callback.seekRelative(tp.topic(), tp.partition(), -1, true)); + getTopicsAndCallbacks() + .forEach((tp, callbacks) -> + callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true)) + ); } public void rewindOnePartitionOneRecord(String topic, int partition) { - getSeekCallbackFor(new org.apache.kafka.common.TopicPartition(topic, partition)) - .seekRelative(topic, partition, -1, true); + getSeekCallbacksFor(new org.apache.kafka.common.TopicPartition(topic, partition)) + .forEach(callback -> callback.seekRelative(topic, partition, -1, true)); } @Override diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java index 169987e1a..5230dc967 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java @@ -17,10 +17,18 @@ package org.springframework.kafka.listener; import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.apache.kafka.common.TopicPartition; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -35,6 +43,7 @@ import org.springframework.kafka.core.KafkaTemplate; import org.springframework.kafka.core.ProducerFactory; import org.springframework.kafka.listener.AbstractConsumerSeekAwareTests.Config.MultiGroupListener; +import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback; import org.springframework.kafka.test.EmbeddedKafkaBroker; import org.springframework.kafka.test.context.EmbeddedKafka; import org.springframework.kafka.test.utils.KafkaTestUtils; @@ -62,6 +71,22 @@ class AbstractConsumerSeekAwareTests { @Autowired MultiGroupListener multiGroupListener; + @Test + public void checkCallbacksAndTopicPartitions() { + await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> { + Map> callbacksAndTopics = multiGroupListener.getCallbacksAndTopics(); + Set registeredCallbacks = callbacksAndTopics.keySet(); + Set registeredTopicPartitions = callbacksAndTopics.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); + + Map> topicsAndCallbacks = multiGroupListener.getTopicsAndCallbacks(); + Set getTopicPartitions = topicsAndCallbacks.keySet(); + Set getCallbacks = topicsAndCallbacks.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); + + assertThat(registeredCallbacks).containsExactlyInAnyOrderElementsOf(getCallbacks).isNotEmpty(); + assertThat(registeredTopicPartitions).containsExactlyInAnyOrderElementsOf(getTopicPartitions).hasSize(3); + }); + } + @Test void seekForAllGroups() throws Exception { template.send(TOPIC, "test-data"); @@ -130,12 +155,12 @@ static class MultiGroupListener extends AbstractConsumerSeekAware { static CountDownLatch latch2 = new CountDownLatch(2); - @KafkaListener(groupId = "group1", topics = TOPIC) + @KafkaListener(groupId = "group1", topics = TOPIC, concurrency = "2") void listenForGroup1(String in) { latch1.countDown(); } - @KafkaListener(groupId = "group2", topics = TOPIC) + @KafkaListener(groupId = "group2", topics = TOPIC, concurrency = "2") void listenForGroup2(String in) { latch2.countDown(); } diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java index d579423f3..88a17e74e 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java @@ -36,6 +36,7 @@ /** * @author Gary Russell + * @author Borahm Lee * @since 2.6 * */ @@ -104,8 +105,8 @@ class CSA extends AbstractConsumerSeekAware { }; exec1.submit(revoke2).get(); exec2.submit(revoke2).get(); - assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacks", Map.class)).isEmpty(); - assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacksToTopic", Map.class)).isEmpty(); + assertThat(KafkaTestUtils.getPropertyValue(csa, "topicToCallbacks", Map.class)).isEmpty(); + assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackToTopics", Map.class)).isEmpty(); var checkTL = (Callable) () -> { csa.unregisterSeekCallback(); assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackForThread", Map.class).get(Thread.currentThread())) diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java index 17e8a9e39..a1355692c 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java @@ -144,6 +144,7 @@ * @author Soby Chacko * @author Wang Zhiyang * @author Mikael Carlstedt + * @author Borahm Lee */ @EmbeddedKafka(topics = { KafkaMessageListenerContainerTests.topic1, KafkaMessageListenerContainerTests.topic2, KafkaMessageListenerContainerTests.topic3, KafkaMessageListenerContainerTests.topic4, @@ -2595,16 +2596,18 @@ public void onPartitionsAssigned(Map assignments, Consumer public void onMessage(ConsumerRecord data) { if (data.partition() == 0 && data.offset() == 0) { TopicPartition topicPartition = new TopicPartition(data.topic(), data.partition()); - final ConsumerSeekCallback seekCallbackFor = getSeekCallbackFor(topicPartition); - assertThat(seekCallbackFor).isNotNull(); - seekCallbackFor.seekToBeginning(records.keySet()); - Iterator iterator = records.keySet().iterator(); - seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next())); - seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next())); - seekCallbackFor.seekToEnd(records.keySet()); - iterator = records.keySet().iterator(); - seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next())); - seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next())); + final List seekCallbacksFor = getSeekCallbacksFor(topicPartition); + assertThat(seekCallbacksFor).isNotEmpty(); + seekCallbacksFor.forEach(callback -> { + callback.seekToBeginning(records.keySet()); + Iterator iterator = records.keySet().iterator(); + callback.seekToBeginning(Collections.singletonList(iterator.next())); + callback.seekToBeginning(Collections.singletonList(iterator.next())); + callback.seekToEnd(records.keySet()); + iterator = records.keySet().iterator(); + callback.seekToEnd(Collections.singletonList(iterator.next())); + callback.seekToEnd(Collections.singletonList(iterator.next())); + }); } }