diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
index 94926c3ef5..05c060ca4a 100644
--- a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
+++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
@@ -3005,13 +3005,20 @@ private void timedAcks() {
 		}
 
 		private void processSeeks() {
-			processTimestampSeeks();
+			Collection<TopicPartition> assigned = getAssignedPartitions();
+			processTimestampSeeks(assigned);
 			TopicPartitionOffset offset = this.seeks.poll();
 			while (offset != null) {
 				traceSeek(offset);
 				try {
-					SeekPosition position = offset.getPosition();
 					TopicPartition topicPartition = offset.getTopicPartition();
+					if (assigned == null || !assigned.contains(topicPartition)) {
+						this.logger.warn("No current assignment for partition " + topicPartition
+								+ " due to partition reassignment prior to seeking.");
+						offset = this.seeks.poll();
+						continue;
+					}
+					SeekPosition position = offset.getPosition();
 					Long whereTo = offset.getOffset();
 					Function<Long, Long> offsetComputeFunction = offset.getOffsetComputeFunction();
 					if (position == null) {
@@ -3056,11 +3063,17 @@ else if (SeekPosition.TIMESTAMP.equals(position)) {
 			}
 		}
 
-		private void processTimestampSeeks() {
+		private void processTimestampSeeks(@Nullable Collection<TopicPartition> assigned) {
 			Iterator<TopicPartitionOffset> seekIterator = this.seeks.iterator();
 			Map<TopicPartition, Long> timestampSeeks = null;
 			while (seekIterator.hasNext()) {
 				TopicPartitionOffset tpo = seekIterator.next();
+				if (assigned == null || !assigned.contains(tpo.getTopicPartition())) {
+					this.logger.warn("No current assignment for partition " + tpo.getTopicPartition()
+							+ " due to partition reassignment prior to seeking.");
+					seekIterator.remove();
+					continue;
+				}
 				if (SeekPosition.TIMESTAMP.equals(tpo.getPosition())) {
 					if (timestampSeeks == null) {
 						timestampSeeks = new HashMap<>();
diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java
index 41887f3d44..2c09187277 100644
--- a/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java
+++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java
@@ -58,7 +58,9 @@
  */
 @DirtiesContext
 @SpringJUnitConfig
-@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC}, partitions = 3)
+@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC},
+		partitions = 9,
+		brokerProperties = "group.initial.rebalance.delay.ms:4000")
 class AbstractConsumerSeekAwareTests {
 
 	static final String TOPIC = "Seek";
@@ -74,7 +76,7 @@ class AbstractConsumerSeekAwareTests {
 
 	@Test
 	public void checkCallbacksAndTopicPartitions() {
-		await().timeout(Duration.ofSeconds(5))
+		await().timeout(Duration.ofSeconds(15))
 				.untilAsserted(() -> {
 					Map<ConsumerSeekCallback, List<TopicPartition>> callbacksAndTopics =
 							multiGroupListener.getCallbacksAndTopics();
@@ -103,29 +105,29 @@ public void checkCallbacksAndTopicPartitions() {
 	void seekForAllGroups() throws Exception {
 		template.send(TOPIC, "test-data");
 		template.send(TOPIC, "test-data");
-		assertThat(MultiGroupListener.latch1.await(30, TimeUnit.SECONDS)).isTrue();
-		assertThat(MultiGroupListener.latch2.await(30, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
 
 		MultiGroupListener.latch1 = new CountDownLatch(2);
 		MultiGroupListener.latch2 = new CountDownLatch(2);
 
 		multiGroupListener.seekToBeginning();
-		assertThat(MultiGroupListener.latch1.await(30, TimeUnit.SECONDS)).isTrue();
-		assertThat(MultiGroupListener.latch2.await(30, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
 	}
 
 	@Test
 	void seekForSpecificGroup() throws Exception {
 		template.send(TOPIC, "test-data");
 		template.send(TOPIC, "test-data");
-		assertThat(MultiGroupListener.latch1.await(30, TimeUnit.SECONDS)).isTrue();
-		assertThat(MultiGroupListener.latch2.await(30, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
+		assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
 
 		MultiGroupListener.latch1 = new CountDownLatch(2);
 		MultiGroupListener.latch2 = new CountDownLatch(2);
 
-		multiGroupListener.seekToBeginningForGroup("group2");
-		assertThat(MultiGroupListener.latch2.await(30, TimeUnit.SECONDS)).isTrue();
+		multiGroupListener.seekToBeginningFor("group2");
+		assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
 		assertThat(MultiGroupListener.latch1.await(1, TimeUnit.SECONDS)).isFalse();
 		assertThat(MultiGroupListener.latch1.getCount()).isEqualTo(2);
 	}
@@ -168,19 +170,19 @@ static class MultiGroupListener extends AbstractConsumerSeekAware {
 
 		static CountDownLatch latch2 = new CountDownLatch(2);
 
-		@KafkaListener(groupId = "group1", topics = TOPIC/*TODO until we figure out non-relevant partitions on assignment, concurrency = "2"*/)
+		@KafkaListener(groupId = "group1", topics = TOPIC, concurrency = "2")
 		void listenForGroup1(String in) {
 			latch1.countDown();
 		}
 
-		@KafkaListener(groupId = "group2", topics = TOPIC/*TODO until we figure out non-relevant partitions on assignment, concurrency = "2"*/)
+		@KafkaListener(groupId = "group2", topics = TOPIC, concurrency = "7")
 		void listenForGroup2(String in) {
 			latch2.countDown();
 		}
 
-		void seekToBeginningForGroup(String groupIdForSeek) {
+		void seekToBeginningFor(String groupId) {
 			getCallbacksAndTopics().forEach((cb, topics) -> {
-				if (groupIdForSeek.equals(cb.getGroupId())) {
+				if (groupId.equals(cb.getGroupId())) {
 					topics.forEach(tp -> cb.seekToBeginning(tp.topic(), tp.partition()));
 				}
 			});