From 830a07012b811ce5417426fdb455027ca84d6f7e Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 11 Oct 2024 12:13:45 -0500 Subject: [PATCH] Subscribe to remote presence changes before setting the key --- .../push/ClientPresenceManager.java | 66 +++++++++++++++---- .../push/ClientPresenceManagerTest.java | 12 +++- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java index 86939abb3..39a212fb6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java @@ -13,12 +13,12 @@ import io.lettuce.core.RedisFuture; import io.lettuce.core.ScriptOutputType; import io.lettuce.core.cluster.SlotHash; -import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; @@ -27,6 +27,8 @@ import java.util.Random; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -63,6 +65,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter pruneMissingPeersFuture; private final Map displacementListenersByPresenceKey = new ConcurrentHashMap<>(); + private final Map> pendingPresenceSetsByPresenceKey = new ConcurrentHashMap<>(); private final Timer checkPresenceTimer; private final Timer setPresenceTimer; @@ -167,14 +170,30 @@ public void setPresent(final UUID accountUuid, final byte deviceId, displacementListenersByPresenceKey.put(presenceKey, displacementListener); - presenceCluster.useCluster(connection -> { - final RedisAdvancedClusterCommands commands = connection.sync(); + final CompletableFuture presenceFuture = new CompletableFuture<>(); + final CompletionStage previousFuture = pendingPresenceSetsByPresenceKey.put(presenceKey, presenceFuture); + if (previousFuture != null) { + log.warn("Unexpected pending presence"); + + } + + subscribeForRemotePresenceChanges(presenceKey); + + presenceCluster.withCluster(connection -> { + final RedisAdvancedClusterAsyncCommands commands = connection.async(); commands.sadd(connectedClientSetKey, presenceKey); - commands.setex(presenceKey, PRESENCE_EXPIRATION_SECONDS, managerId); + return commands.setex(presenceKey, PRESENCE_EXPIRATION_SECONDS, managerId); + }).whenComplete((result, throwable) -> { + if (throwable != null) { + presenceFuture.completeExceptionally(throwable); + } else { + presenceFuture.complete(null); + } }); - subscribeForRemotePresenceChanges(presenceKey); + presenceFuture.whenComplete( + (ignored, throwable) -> pendingPresenceSetsByPresenceKey.remove(presenceKey, presenceFuture)); }); } @@ -308,19 +327,38 @@ public void message(final RedisClusterNode node, final String channel, final Str if (channel.startsWith("__keyspace@0__:presence::{")) { if ("set".equals(message) || "del".equals(message)) { - // for "set", another process has overwritten this presence key, which means the client has connected to another host. + // "set" might mean the client has connected to another host, although it might just be our own `set`, + // because we subscribe for changes before setting the key. // for "del", another process has indicated the client should be disconnected - final boolean connectedElsewhere = "set".equals(message); + final boolean maybeConnectedElsewhere = "set".equals(message); // At this point, we're on a Lettuce IO thread and need to dispatch to a separate thread before making // synchronous Lettuce calls to avoid deadlocking. keyspaceNotificationExecutorService.execute(() -> { - try { - displacePresence(channel.substring("__keyspace@0__:".length()), connectedElsewhere); - remoteDisplacementMeter.increment(); - } catch (final Exception e) { - log.warn("Error displacing presence", e); - } + final String clientPresenceKey = channel.substring("__keyspace@0__:".length()); + + final CompletionStage pendingConnection = pendingPresenceSetsByPresenceKey.getOrDefault(clientPresenceKey, + CompletableFuture.completedFuture(null)); + + pendingConnection.thenCompose(ignored -> { + if (maybeConnectedElsewhere) { + return presenceCluster.withCluster(connection -> connection.async().get(clientPresenceKey)) + .thenApply(currentManagerId -> !managerId.equals(currentManagerId)); + } + + return CompletableFuture.completedFuture(true); + }) + .exceptionally(ignored -> true) + .thenAcceptAsync(shouldDisplace -> { + if (shouldDisplace) { + try { + displacePresence(clientPresenceKey, maybeConnectedElsewhere); + remoteDisplacementMeter.increment(); + } catch (final Exception e) { + log.warn("Error displacing presence", e); + } + } + }, keyspaceNotificationExecutorService); }); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java index 5ccac117f..474ca73d6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java @@ -386,8 +386,8 @@ void testDisconnectPresenceRemotely() { displaced.join(); } - @RepeatedTest(value = 10, failureThreshold = 1) - void testConcurrentConnection() { + @RepeatedTest(value = 100) + void testConcurrentConnection() throws Exception { final UUID uuid1 = UUID.randomUUID(); final byte deviceId = 1; @@ -400,7 +400,13 @@ void testConcurrentConnection() { server1Thread.start(); server2Thread.start(); - assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join); + displaced.join(); + server2Thread.join(); + server1Thread.join(); + + while (server1.isLocallyPresent(uuid1, deviceId) == server2.isLocallyPresent(uuid1, deviceId)) { + Thread.sleep(50); + } } }