Skip to content

Commit

Permalink
GH-3328: Add missing seek callbacks on each topic partition
Browse files Browse the repository at this point in the history
Fixes: #3328

When using an `AbstractConsumerSeekAware` in a multi-group listeners scenario, there are cases where the number of registered callbacks differs from the number of discovered callbacks.
This is due to the value type of callbacks Map in `AbstractConsumerSeekAware` class being simply `ConsumerSeekCallback`.
This causes some callbacks looking at the same partition to be missing.

* Change the value type of callbacks Map in `AbstractConsumerSeekAware` class from `ConsumerSeekCallback` to `List<ConsumerSeekCallback>`.
* Also modify some methods, test codes and docs that are affected by this change.
* Add test codes to verify that the callbacks registered via `registeredSeekCallback()` and the ones you can get via `getSeekCallbacks()` match completely.
  • Loading branch information
bky373 authored and artembilan committed Jul 11, 2024
1 parent f91f8a9 commit 0fcdf92
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,18 @@ public class SeekToLastOnIdleListener extends AbstractConsumerSeekAware {
* Rewind all partitions one record.
*/
public void rewindAllOneRecord() {
getSeekCallbacks()
.forEach((tp, callback) ->
callback.seekRelative(tp.topic(), tp.partition(), -1, true));
getTopicsAndCallbacks()
.forEach((tp, callbacks) ->
callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true))
);
}
/**
* Rewind one partition one record.
*/
public void rewindOnePartitionOneRecord(String topic, int partition) {
getSeekCallbackFor(new TopicPartition(topic, partition))
.seekRelative(topic, partition, -1, true);
getSeekCallbacksFor(new TopicPartition(topic, partition))
.forEach(callback -> callback.seekRelative(topic, partition, -1, true));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ The naming convention for DLT topics has been standardized to use the "-dlt" suf

A new method, `getGroupId()`, has been added to the `ConsumerSeekCallback` interface.
This method allows for more selective seek operations by targeting only the desired consumer group.
The `AbstractConsumerSeekAware` can also now register, retrieve, and remove all callbacks for each topic partition in a multi-group listener scenario without missing any.
See the new APIs (`getSeekCallbacksFor(TopicPartition topicPartition)`, `getTopicsAndCallbacks()`) for more details.
For more details, see xref:kafka/seek.adoc#seek[Seek API Docs].

[[x33-new-option-ignore-empty-batch]]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2023 the original author or authors.
* Copyright 2019-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,33 +16,37 @@

package org.springframework.kafka.listener;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import org.apache.kafka.common.TopicPartition;

import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;

/**
* Manages the {@link ConsumerSeekAware.ConsumerSeekCallback} s for the listener. If the
* listener subclasses this class, it can easily seek arbitrary topics/partitions without
* having to keep track of the callbacks itself.
*
* @author Gary Russell
* @author Borahm Lee
* @since 2.3
*
*/
public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware {

private final Map<Thread, ConsumerSeekCallback> callbackForThread = new ConcurrentHashMap<>();

private final Map<TopicPartition, ConsumerSeekCallback> callbacks = new ConcurrentHashMap<>();
private final Map<TopicPartition, List<ConsumerSeekCallback>> topicToCallbacks = new ConcurrentHashMap<>();

private final Map<ConsumerSeekCallback, List<TopicPartition>> callbacksToTopic = new ConcurrentHashMap<>();
private final Map<ConsumerSeekCallback, List<TopicPartition>> callbackToTopics = new ConcurrentHashMap<>();

@Override
public void registerSeekCallback(ConsumerSeekCallback callback) {
Expand All @@ -54,24 +58,26 @@ public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consumer
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
if (threadCallback != null) {
assignments.keySet().forEach(tp -> {
this.callbacks.put(tp, threadCallback);
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
this.topicToCallbacks.computeIfAbsent(tp, key -> new ArrayList<>()).add(threadCallback);
this.callbackToTopics.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
});
}
}

@Override
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
partitions.forEach(tp -> {
ConsumerSeekCallback removed = this.callbacks.remove(tp);
if (removed != null) {
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
if (topics != null) {
topics.remove(tp);
if (topics.size() == 0) {
this.callbacksToTopic.remove(removed);
List<ConsumerSeekCallback> removedCallbacks = this.topicToCallbacks.remove(tp);
if (removedCallbacks != null && !removedCallbacks.isEmpty()) {
removedCallbacks.forEach(cb -> {
List<TopicPartition> topics = this.callbackToTopics.get(cb);
if (topics != null) {
topics.remove(tp);
if (topics.isEmpty()) {
this.callbackToTopics.remove(cb);
}
}
}
});
}
});
}
Expand All @@ -82,21 +88,55 @@ public void unregisterSeekCallback() {
}

/**
* Return the callback for the specified topic/partition.
* @param topicPartition the topic/partition.
* @return the callback (or null if there is no assignment).
*/
* Return the callback for the specified topic/partition.
* @param topicPartition the topic/partition.
* @return the callback (or null if there is no assignment).
* @deprecated Replaced by {@link #getSeekCallbacksFor(TopicPartition)}
*/
@Deprecated(since = "3.3", forRemoval = true)
@Nullable
protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition) {
return this.callbacks.get(topicPartition);
List<ConsumerSeekCallback> callbacks = getSeekCallbacksFor(topicPartition);
if (CollectionUtils.isEmpty(callbacks)) {
return null;
}
return callbacks.get(0);
}

/**
* Return the callbacks for the specified topic/partition.
* @param topicPartition the topic/partition.
* @return the callbacks (or null if there is no assignment).
* @since 3.3
*/
@Nullable
protected List<ConsumerSeekCallback> getSeekCallbacksFor(TopicPartition topicPartition) {
return this.topicToCallbacks.get(topicPartition);
}

/**
* The map of callbacks for all currently assigned partitions.
* @return the map.
* @deprecated Replaced by {@link #getTopicsAndCallbacks()}
*/
@Deprecated(since = "3.3", forRemoval = true)
protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
return Collections.unmodifiableMap(this.callbacks);
Map<TopicPartition, List<ConsumerSeekCallback>> topicsAndCallbacks = getTopicsAndCallbacks();
return topicsAndCallbacks.entrySet().stream()
.filter(entry -> !entry.getValue().isEmpty())
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> entry.getValue().get(0)
));
}

/**
* The map of callbacks for all currently assigned partitions.
* @return the map.
* @since 3.3
*/
protected Map<TopicPartition, List<ConsumerSeekCallback>> getTopicsAndCallbacks() {
return Collections.unmodifiableMap(this.topicToCallbacks);
}

/**
Expand All @@ -105,23 +145,23 @@ protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
* @since 2.6
*/
protected Map<ConsumerSeekCallback, List<TopicPartition>> getCallbacksAndTopics() {
return Collections.unmodifiableMap(this.callbacksToTopic);
return Collections.unmodifiableMap(this.callbackToTopics);
}

/**
* Seek all assigned partitions to the beginning.
* @since 2.6
*/
public void seekToBeginning() {
getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToBeginning(topics));
getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToBeginning);
}

/**
* Seek all assigned partitions to the end.
* @since 2.6
*/
public void seekToEnd() {
getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToEnd(topics));
getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToEnd);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
* @author Nakul Mishra
* @author Soby Chacko
* @author Wang Zhiyang
* @author Borahm Lee
*/
@SpringJUnitConfig
@DirtiesContext
Expand Down Expand Up @@ -1081,7 +1082,7 @@ public void testSeekToLastOnIdle() throws InterruptedException {
assertThat(this.seekOnIdleListener.latch3.await(10, TimeUnit.SECONDS)).isTrue();
this.registry.getListenerContainer("seekOnIdle").stop();
assertThat(this.seekOnIdleListener.latch4.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "callbacks", Map.class)).hasSize(0);
assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "topicToCallbacks", Map.class)).hasSize(0);
}

@SuppressWarnings({"unchecked", "rawtypes"})
Expand Down Expand Up @@ -2523,11 +2524,10 @@ public void listen(String in) throws InterruptedException {
if (latch1.getCount() > 0) {
latch1.countDown();
if (latch1.getCount() == 0) {
ConsumerSeekCallback seekToComputeFn = getSeekCallbackFor(
List<ConsumerSeekCallback> seekToComputeFunctions = getSeekCallbacksFor(
new org.apache.kafka.common.TopicPartition("seekToComputeFn", 0));
assertThat(seekToComputeFn).isNotNull();
seekToComputeFn.
seek("seekToComputeFn", 0, current -> 0L);
assertThat(seekToComputeFunctions).isNotEmpty();
seekToComputeFunctions.forEach(callback -> callback.seek("seekToComputeFn", 0, current -> 0L));
}
}
}
Expand Down Expand Up @@ -2576,14 +2576,15 @@ public void onIdleContainer(Map<org.apache.kafka.common.TopicPartition, Long> as
}

public void rewindAllOneRecord() {
getSeekCallbacks()
.forEach((tp, callback) ->
callback.seekRelative(tp.topic(), tp.partition(), -1, true));
getTopicsAndCallbacks()
.forEach((tp, callbacks) ->
callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true))
);
}

public void rewindOnePartitionOneRecord(String topic, int partition) {
getSeekCallbackFor(new org.apache.kafka.common.TopicPartition(topic, partition))
.seekRelative(topic, partition, -1, true);
getSeekCallbacksFor(new org.apache.kafka.common.TopicPartition(topic, partition))
.forEach(callback -> callback.seekRelative(topic, partition, -1, true));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@
package org.springframework.kafka.listener;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

import java.time.Duration;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.apache.kafka.common.TopicPartition;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -35,6 +43,7 @@
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.core.ProducerFactory;
import org.springframework.kafka.listener.AbstractConsumerSeekAwareTests.Config.MultiGroupListener;
import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback;
import org.springframework.kafka.test.EmbeddedKafkaBroker;
import org.springframework.kafka.test.context.EmbeddedKafka;
import org.springframework.kafka.test.utils.KafkaTestUtils;
Expand Down Expand Up @@ -62,6 +71,22 @@ class AbstractConsumerSeekAwareTests {
@Autowired
MultiGroupListener multiGroupListener;

@Test
public void checkCallbacksAndTopicPartitions() {
await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> {
Map<ConsumerSeekCallback, List<TopicPartition>> callbacksAndTopics = multiGroupListener.getCallbacksAndTopics();
Set<ConsumerSeekCallback> registeredCallbacks = callbacksAndTopics.keySet();
Set<TopicPartition> registeredTopicPartitions = callbacksAndTopics.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());

Map<TopicPartition, List<ConsumerSeekCallback>> topicsAndCallbacks = multiGroupListener.getTopicsAndCallbacks();
Set<TopicPartition> getTopicPartitions = topicsAndCallbacks.keySet();
Set<ConsumerSeekCallback> getCallbacks = topicsAndCallbacks.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());

assertThat(registeredCallbacks).containsExactlyInAnyOrderElementsOf(getCallbacks).isNotEmpty();
assertThat(registeredTopicPartitions).containsExactlyInAnyOrderElementsOf(getTopicPartitions).hasSize(3);
});
}

@Test
void seekForAllGroups() throws Exception {
template.send(TOPIC, "test-data");
Expand Down Expand Up @@ -130,12 +155,12 @@ static class MultiGroupListener extends AbstractConsumerSeekAware {

static CountDownLatch latch2 = new CountDownLatch(2);

@KafkaListener(groupId = "group1", topics = TOPIC)
@KafkaListener(groupId = "group1", topics = TOPIC, concurrency = "2")
void listenForGroup1(String in) {
latch1.countDown();
}

@KafkaListener(groupId = "group2", topics = TOPIC)
@KafkaListener(groupId = "group2", topics = TOPIC, concurrency = "2")
void listenForGroup2(String in) {
latch2.countDown();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

/**
* @author Gary Russell
* @author Borahm Lee
* @since 2.6
*
*/
Expand Down Expand Up @@ -104,8 +105,8 @@ class CSA extends AbstractConsumerSeekAware {
};
exec1.submit(revoke2).get();
exec2.submit(revoke2).get();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacks", Map.class)).isEmpty();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacksToTopic", Map.class)).isEmpty();
assertThat(KafkaTestUtils.getPropertyValue(csa, "topicToCallbacks", Map.class)).isEmpty();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackToTopics", Map.class)).isEmpty();
var checkTL = (Callable<Void>) () -> {
csa.unregisterSeekCallback();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackForThread", Map.class).get(Thread.currentThread()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
* @author Soby Chacko
* @author Wang Zhiyang
* @author Mikael Carlstedt
* @author Borahm Lee
*/
@EmbeddedKafka(topics = { KafkaMessageListenerContainerTests.topic1, KafkaMessageListenerContainerTests.topic2,
KafkaMessageListenerContainerTests.topic3, KafkaMessageListenerContainerTests.topic4,
Expand Down Expand Up @@ -2595,16 +2596,18 @@ public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consumer
public void onMessage(ConsumerRecord<String, String> data) {
if (data.partition() == 0 && data.offset() == 0) {
TopicPartition topicPartition = new TopicPartition(data.topic(), data.partition());
final ConsumerSeekCallback seekCallbackFor = getSeekCallbackFor(topicPartition);
assertThat(seekCallbackFor).isNotNull();
seekCallbackFor.seekToBeginning(records.keySet());
Iterator<TopicPartition> iterator = records.keySet().iterator();
seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next()));
seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next()));
seekCallbackFor.seekToEnd(records.keySet());
iterator = records.keySet().iterator();
seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next()));
seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next()));
final List<ConsumerSeekCallback> seekCallbacksFor = getSeekCallbacksFor(topicPartition);
assertThat(seekCallbacksFor).isNotEmpty();
seekCallbacksFor.forEach(callback -> {
callback.seekToBeginning(records.keySet());
Iterator<TopicPartition> iterator = records.keySet().iterator();
callback.seekToBeginning(Collections.singletonList(iterator.next()));
callback.seekToBeginning(Collections.singletonList(iterator.next()));
callback.seekToEnd(records.keySet());
iterator = records.keySet().iterator();
callback.seekToEnd(Collections.singletonList(iterator.next()));
callback.seekToEnd(Collections.singletonList(iterator.next()));
});
}
}

Expand Down

0 comments on commit 0fcdf92

Please sign in to comment.