Skip to content

Commit

Permalink
GH-3373: Enhancing handling of seeking failures due to consumer rebalancing

Browse files Browse the repository at this point in the history

Fixes: #3373

* Exclude unassigned partitions during seeking
* Set group initial rebalancing delay for test
* Rename method
  • Loading branch information
bky373 authored Aug 8, 2024
1 parent 4c1ac20 commit cc926b3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3005,13 +3005,20 @@ private void timedAcks() {
}

private void processSeeks() {
processTimestampSeeks();
Collection<TopicPartition> assigned = getAssignedPartitions();
processTimestampSeeks(assigned);
TopicPartitionOffset offset = this.seeks.poll();
while (offset != null) {
traceSeek(offset);
try {
SeekPosition position = offset.getPosition();
TopicPartition topicPartition = offset.getTopicPartition();
if (assigned == null || !assigned.contains(topicPartition)) {
this.logger.warn("No current assignment for partition " + topicPartition +
" due to partition reassignment prior to seeking.");
offset = this.seeks.poll();
continue;
}
SeekPosition position = offset.getPosition();
Long whereTo = offset.getOffset();
Function<Long, Long> offsetComputeFunction = offset.getOffsetComputeFunction();
if (position == null) {
Expand Down Expand Up @@ -3056,11 +3063,17 @@ else if (SeekPosition.TIMESTAMP.equals(position)) {
}
}

private void processTimestampSeeks() {
private void processTimestampSeeks(@Nullable Collection<TopicPartition> assigned) {
Iterator<TopicPartitionOffset> seekIterator = this.seeks.iterator();
Map<TopicPartition, Long> timestampSeeks = null;
while (seekIterator.hasNext()) {
TopicPartitionOffset tpo = seekIterator.next();
if (assigned == null || !assigned.contains(tpo.getTopicPartition())) {
this.logger.warn("No current assignment for partition " + tpo.getTopicPartition() +
" due to partition reassignment prior to seeking.");
seekIterator.remove();
continue;
}
if (SeekPosition.TIMESTAMP.equals(tpo.getPosition())) {
if (timestampSeeks == null) {
timestampSeeks = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@
*/
@DirtiesContext
@SpringJUnitConfig
@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC}, partitions = 3)
@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC},
partitions = 9,
brokerProperties = "group.initial.rebalance.delay.ms:4000")
class AbstractConsumerSeekAwareTests {

static final String TOPIC = "Seek";
Expand All @@ -74,7 +76,7 @@ class AbstractConsumerSeekAwareTests {

@Test
public void checkCallbacksAndTopicPartitions() {
await().timeout(Duration.ofSeconds(5))
await().timeout(Duration.ofSeconds(15))
.untilAsserted(() -> {
Map<ConsumerSeekCallback, List<TopicPartition>> callbacksAndTopics =
multiGroupListener.getCallbacksAndTopics();
Expand Down Expand Up @@ -103,29 +105,29 @@ public void checkCallbacksAndTopicPartitions() {
// Verifies that seekToBeginning() re-delivers records to every consumer group:
// both groups consume the two records, then a seek-to-beginning replays them.
void seekForAllGroups() throws Exception {
	template.send(TOPIC, "test-data");
	template.send(TOPIC, "test-data");
	assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
	assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();

	// Re-arm the latches so the replayed records can be counted separately.
	MultiGroupListener.latch1 = new CountDownLatch(2);
	MultiGroupListener.latch2 = new CountDownLatch(2);

	multiGroupListener.seekToBeginning();
	assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
	assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
}

@Test
// Verifies that seekToBeginningFor("group2") replays records ONLY for group2:
// group2's latch fires, while group1's latch stays at its initial count.
void seekForSpecificGroup() throws Exception {
	template.send(TOPIC, "test-data");
	template.send(TOPIC, "test-data");
	assertThat(MultiGroupListener.latch1.await(15, TimeUnit.SECONDS)).isTrue();
	assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();

	// Re-arm the latches so replayed records are counted separately.
	MultiGroupListener.latch1 = new CountDownLatch(2);
	MultiGroupListener.latch2 = new CountDownLatch(2);

	multiGroupListener.seekToBeginningFor("group2");
	assertThat(MultiGroupListener.latch2.await(15, TimeUnit.SECONDS)).isTrue();
	// group1 must NOT see the replay: short await fails and the count is untouched.
	assertThat(MultiGroupListener.latch1.await(1, TimeUnit.SECONDS)).isFalse();
	assertThat(MultiGroupListener.latch1.getCount()).isEqualTo(2);
}
Expand Down Expand Up @@ -168,19 +170,19 @@ static class MultiGroupListener extends AbstractConsumerSeekAware {

static CountDownLatch latch2 = new CountDownLatch(2);

// group1 listener; concurrency = "2" so seeks race with rebalancing (GH-3373).
@KafkaListener(groupId = "group1", topics = TOPIC, concurrency = "2")
void listenForGroup1(String in) {
	latch1.countDown();
}

// group2 listener; concurrency = "7" so seeks race with rebalancing (GH-3373).
@KafkaListener(groupId = "group2", topics = TOPIC, concurrency = "7")
void listenForGroup2(String in) {
	latch2.countDown();
}

void seekToBeginningForGroup(String groupIdForSeek) {
void seekToBeginningFor(String groupId) {
getCallbacksAndTopics().forEach((cb, topics) -> {
if (groupIdForSeek.equals(cb.getGroupId())) {
if (groupId.equals(cb.getGroupId())) {
topics.forEach(tp -> cb.seekToBeginning(tp.topic(), tp.partition()));
}
});
Expand Down

0 comments on commit cc926b3

Please sign in to comment.