diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelector.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelector.java index b2b2b512c8cfc..3a5397722fed1 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelector.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelector.java @@ -18,17 +18,19 @@ */ package org.apache.pulsar.broker.service; -import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NavigableMap; import java.util.TreeMap; +import java.util.WeakHashMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import org.apache.commons.lang3.mutable.MutableInt; import org.apache.pulsar.client.api.Range; import org.apache.pulsar.common.util.Murmur3_32Hash; @@ -44,15 +46,118 @@ public class ConsistentHashingStickyKeyConsumerSelector implements StickyKeyCons private final ReadWriteLock rwLock = new ReentrantReadWriteLock(); // Consistent-Hash ring - private final NavigableMap> hashRing; + private final NavigableMap hashRing; + // used for distributing consumer instance selections evenly in the hash ring when there + // are multiple instances of consumer with the same consumer name or when there are hash collisions + private final Map consumerSelectionCounters; private final int numberOfPoints; public ConsistentHashingStickyKeyConsumerSelector(int numberOfPoints) { this.hashRing = new TreeMap<>(); + this.consumerSelectionCounters = new WeakHashMap<>(); this.numberOfPoints = numberOfPoints; } + /** + * This class is used to store the consumers and the selected consumer for a hash value in the hash ring. + * This attempts to distribute the consumers evenly in the hash ring for consumers with the same + * consumer name and priority level. These entries collide in the hash ring. + * The selected consumer is the consumer that is selected to serve the hash value. + * It is not changed unless a consumer is removed or a colliding consumer with higher priority or + * lower selection count is added. + */ + private static class HashRingEntry { + // This class is used to store the consumer which it was added to the hash ring + // sorting will be by priority, consumer name and usage count of the consumer instance + record ConsumerEntry(Consumer consumer, MutableInt consumerSelectionCounter) + implements Comparable { + private static final Comparator + BASE_CONSUMER_ENTRY_COMPARATOR = Comparator. + comparing(entry -> entry.consumer().getPriorityLevel()).reversed() + .thenComparing(entry -> entry.consumer().consumerName()); + + + private static final Comparator + CONSUMER_ENTRY_COMPARATOR = BASE_CONSUMER_ENTRY_COMPARATOR + // prefer the consumer instance with lowest selection count so that consumers get + // evenly distributed + .thenComparing(ConsumerEntry::consumerSelectionCounter); + + @Override + public int compareTo(ConsumerEntry o) { + return CONSUMER_ENTRY_COMPARATOR.compare(this, o); + } + + // comparison without the usage count so that the consumer doesn't get changed too eagerly + // when entries are removed + public int baseCompareTo(ConsumerEntry o) { + return BASE_CONSUMER_ENTRY_COMPARATOR.compare(this, o); + } + } + + private final List consumers; + ConsumerEntry selectedConsumerEntry; + + public HashRingEntry() { + this.consumers = new ArrayList<>(); + } + + public void addConsumer(Consumer consumer, MutableInt selectedCounter) { + consumers.add(new ConsumerEntry(consumer, selectedCounter)); + selectConsumer(null); + } + + public boolean removeConsumer(Consumer consumer) { + boolean removed = consumers.removeIf(consumerEntry -> consumerEntry.consumer().equals(consumer)); + selectConsumer(consumer); + return removed; + } + + public Consumer getSelectedConsumer() { + return selectedConsumerEntry != null ? selectedConsumerEntry.consumer() : null; + } + + private void selectConsumer(Consumer removedConsumer) { + if (consumers.size() > 1) { + boolean addOperation = removedConsumer == null; + if (addOperation) { + Collections.sort(consumers); + } + ConsumerEntry newSelectedConsumer = consumers.get(0); + // change the selected consumer only if the newer has higher priority, + // or the same priority and an earlier name in sorting order + if (selectedConsumerEntry == null || addOperation + || selectedConsumerEntry.consumer.equals(removedConsumer) + || selectedConsumerEntry.baseCompareTo(newSelectedConsumer) > 0) { + changeSelectedConsumerEntry(newSelectedConsumer); + } + } else if (consumers.size() == 1) { + changeSelectedConsumerEntry(consumers.get(0)); + } else { + changeSelectedConsumerEntry(null); + } + } + + private void changeSelectedConsumerEntry(ConsumerEntry newSelectedConsumer) { + beforeChangingSelectedConsumerEntry(); + selectedConsumerEntry = newSelectedConsumer; + afterChangingSelectedConsumerEntry(); + } + + private void beforeChangingSelectedConsumerEntry() { + if (selectedConsumerEntry != null) { + selectedConsumerEntry.consumerSelectionCounter.decrement(); + } + } + + private void afterChangingSelectedConsumerEntry() { + if (selectedConsumerEntry != null) { + selectedConsumerEntry.consumerSelectionCounter.increment(); + } + } + } + @Override public CompletableFuture addConsumer(Consumer consumer) { rwLock.writeLock().lock(); @@ -61,17 +166,9 @@ public CompletableFuture addConsumer(Consumer consumer) { // The points are deterministically added based on the hash of the consumer name for (int i = 0; i < numberOfPoints; i++) { int hash = calculateHashForConsumerAndIndex(consumer, i); - hashRing.compute(hash, (k, v) -> { - if (v == null) { - return Lists.newArrayList(consumer); - } else { - if (!v.contains(consumer)) { - v.add(consumer); - v.sort(Comparator.comparing(Consumer::consumerName, String::compareTo)); - } - return v; - } - }); + HashRingEntry hashRingEntry = hashRing.computeIfAbsent(hash, k -> new HashRingEntry()); + // Add the consumer to the hash ring entry + hashRingEntry.addConsumer(consumer, getConsumerSelectedCount(consumer)); } return CompletableFuture.completedFuture(null); } finally { @@ -79,6 +176,10 @@ public CompletableFuture addConsumer(Consumer consumer) { } } + private MutableInt getConsumerSelectedCount(Consumer consumer) { + return consumerSelectionCounters.computeIfAbsent(consumer, k -> new MutableInt()); + } + private static int calculateHashForConsumerAndIndex(Consumer consumer, int index) { String key = consumer.consumerName() + KEY_SEPARATOR + index; return Murmur3_32Hash.getInstance().makeHash(key.getBytes()); @@ -92,15 +193,11 @@ public void removeConsumer(Consumer consumer) { for (int i = 0; i < numberOfPoints; i++) { int hash = calculateHashForConsumerAndIndex(consumer, i); hashRing.compute(hash, (k, v) -> { - if (v == null) { - return null; - } else { - v.removeIf(c -> c.equals(consumer)); - if (v.isEmpty()) { - v = null; - } - return v; + v.removeConsumer(consumer); + if (v.getSelectedConsumer() == null) { + v = null; } + return v; }); } } finally { @@ -115,16 +212,14 @@ public Consumer select(int hash) { if (hashRing.isEmpty()) { return null; } - - List consumerList; - Map.Entry> ceilingEntry = hashRing.ceilingEntry(hash); + HashRingEntry hashRingEntry; + Map.Entry ceilingEntry = hashRing.ceilingEntry(hash); if (ceilingEntry != null) { - consumerList = ceilingEntry.getValue(); + hashRingEntry = ceilingEntry.getValue(); } else { - consumerList = hashRing.firstEntry().getValue(); + hashRingEntry = hashRing.firstEntry().getValue(); } - - return consumerList.get(hash % consumerList.size()); + return hashRingEntry.getSelectedConsumer(); } finally { rwLock.readLock().unlock(); } @@ -135,13 +230,24 @@ public Map> getConsumerKeyHashRanges() { Map> result = new LinkedHashMap<>(); rwLock.readLock().lock(); try { + if (hashRing.isEmpty()) { + return result; + } int start = 0; - for (Map.Entry> entry: hashRing.entrySet()) { - for (Consumer consumer: entry.getValue()) { - result.computeIfAbsent(consumer, key -> new ArrayList<>()) + int lastKey = 0; + for (Map.Entry entry: hashRing.entrySet()) { + Consumer consumer = entry.getValue().getSelectedConsumer(); + result.computeIfAbsent(consumer, key -> new ArrayList<>()) .add(Range.of(start, entry.getKey())); - } - start = entry.getKey() + 1; + lastKey = entry.getKey(); + start = lastKey + 1; + } + // Handle wrap-around + HashRingEntry firstHashRingEntry = hashRing.firstEntry().getValue(); + Consumer firstSelectedConsumer = firstHashRingEntry.getSelectedConsumer(); + List ranges = result.get(firstSelectedConsumer); + if (lastKey != Integer.MAX_VALUE - 1) { + ranges.add(Range.of(lastKey + 1, Integer.MAX_VALUE - 1)); } } finally { rwLock.readLock().unlock(); diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelectorTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelectorTest.java index 48311c57338b5..7e21579741f0b 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelectorTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ConsistentHashingStickyKeyConsumerSelectorTest.java @@ -18,9 +18,9 @@ */ package org.apache.pulsar.broker.service; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; - import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -31,6 +31,7 @@ import java.util.stream.IntStream; import org.apache.pulsar.broker.service.BrokerServiceException.ConsumerAssignException; import org.apache.pulsar.client.api.Range; +import org.mockito.Mockito; import org.testng.Assert; import org.testng.annotations.Test; @@ -54,7 +55,7 @@ public void testConsumerSelect() throws ConsumerAssignException { selector.addConsumer(consumer2); final int N = 1000; - final double PERCENT_ERROR = 0.20; // 20 % + final double PERCENT_ERROR = 0.25; // 25 % Map selectionMap = new HashMap<>(); for (int i = 0; i < N; i++) { @@ -146,12 +147,17 @@ public void testGetConsumerKeyHashRanges() throws BrokerServiceException.Consume ConsistentHashingStickyKeyConsumerSelector selector = new ConsistentHashingStickyKeyConsumerSelector(3); List consumerName = Arrays.asList("consumer1", "consumer2", "consumer3"); List consumers = new ArrayList<>(); + long id=0; for (String s : consumerName) { - Consumer consumer = mock(Consumer.class); - when(consumer.consumerName()).thenReturn(s); + Consumer consumer = createMockConsumer(s, s, id++); selector.addConsumer(consumer); consumers.add(consumer); } + + // check that results are the same when called multiple times + assertThat(selector.getConsumerKeyHashRanges()) + .containsExactlyEntriesOf(selector.getConsumerKeyHashRanges()); + Map> expectedResult = new HashMap<>(); expectedResult.put(consumers.get(0), Arrays.asList( Range.of(119056335, 242013991), @@ -160,17 +166,47 @@ public void testGetConsumerKeyHashRanges() throws BrokerServiceException.Consume expectedResult.put(consumers.get(1), Arrays.asList( Range.of(0, 90164503), Range.of(90164504, 119056334), - Range.of(382436668, 722195656))); + Range.of(382436668, 722195656), + Range.of(1914695767, 2147483646))); expectedResult.put(consumers.get(2), Arrays.asList( Range.of(242013992, 242377547), Range.of(242377548, 382436667), Range.of(1656011843, 1707482097))); - for (Map.Entry> entry : selector.getConsumerKeyHashRanges().entrySet()) { - System.out.println(entry.getValue()); - Assert.assertEquals(entry.getValue(), expectedResult.get(entry.getKey())); - expectedResult.remove(entry.getKey()); + assertThat(selector.getConsumerKeyHashRanges()).containsExactlyInAnyOrderEntriesOf(expectedResult); + } + + @Test + public void testConsumersGetEvenlyMappedWhenThereAreCollisions() + throws BrokerServiceException.ConsumerAssignException { + ConsistentHashingStickyKeyConsumerSelector selector = new ConsistentHashingStickyKeyConsumerSelector(5); + List consumers = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + // use the same name for all consumers + Consumer consumer = createMockConsumer("consumer", "index " + i, i); + selector.addConsumer(consumer); + consumers.add(consumer); } - Assert.assertEquals(expectedResult.size(), 0); + // check that results are the same when called multiple times + assertThat(selector.getConsumerKeyHashRanges()) + .containsExactlyEntriesOf(selector.getConsumerKeyHashRanges()); + + Map> expectedResult = new HashMap<>(); + expectedResult.put(consumers.get(0), List.of(Range.of(306176209, 365902830))); + expectedResult.put(consumers.get(1), List.of(Range.of(216056714, 306176208))); + expectedResult.put(consumers.get(2), List.of(Range.of(365902831, 1240826377))); + expectedResult.put(consumers.get(3), List.of(Range.of(1240826378, 1862045174))); + expectedResult.put(consumers.get(4), List.of(Range.of(0, 216056713), Range.of(1862045175, 2147483646))); + assertThat(selector.getConsumerKeyHashRanges()).containsExactlyInAnyOrderEntriesOf(expectedResult); + } + + private static Consumer createMockConsumer(String consumerName, String toString, long id) { + // without stubOnly, the mock will record method invocations and run into OOME + Consumer consumer = mock(Consumer.class, Mockito.withSettings().stubOnly()); + when(consumer.consumerName()).thenReturn(consumerName); + when(consumer.getPriorityLevel()).thenReturn(0); + when(consumer.toString()).thenReturn(toString); + when(consumer.consumerId()).thenReturn(id); + return consumer; } // reproduces https://github.com/apache/pulsar/issues/22050 @@ -216,4 +252,46 @@ public void shouldRemoveConsumersFromConsumerKeyHashRanges() { // then there should be no mapping remaining Assert.assertEquals(selector.getConsumerKeyHashRanges().size(), 0); } + + @Test + public void testShouldNotChangeSelectedConsumerWhenConsumerIsRemoved() { + final ConsistentHashingStickyKeyConsumerSelector selector = new ConsistentHashingStickyKeyConsumerSelector(25); + final String consumerName = "consumer"; + final int numOfInitialConsumers = 25; + List consumers = new ArrayList<>(); + for (int i = 0; i < numOfInitialConsumers; i++) { + final Consumer consumer = createMockConsumer(consumerName, "index " + i, i); + consumers.add(consumer); + selector.addConsumer(consumer); + } + + int hashRangeSize = Integer.MAX_VALUE; + int validationPointCount = 100; + int increment = hashRangeSize / validationPointCount; + List selectedConsumerBeforeRemoval = new ArrayList<>(); + + for (int i = 0; i < validationPointCount; i++) { + selectedConsumerBeforeRemoval.add(selector.select(i * increment)); + } + + for (int i = 0; i < validationPointCount; i++) { + Consumer selected = selector.select(i * increment); + Consumer expected = selectedConsumerBeforeRemoval.get(i); + assertThat(selected.consumerId()).as("validationPoint %d", i).isEqualTo(expected.consumerId()); + } + + /* + TODO: failing test case + for (Consumer removedConsumer : consumers) { + selector.removeConsumer(removedConsumer); + for (int i = 0; i < validationPointCount; i++) { + Consumer selected = selector.select(i * increment); + Consumer expected = selectedConsumerBeforeRemoval.get(i); + if (expected != removedConsumer) { + assertThat(selected.consumerId()).as("validationPoint %d", i).isEqualTo(expected.consumerId()); + } + } + } + */ + } }