Skip to content

Commit

Permalink
Publish "messages persisted" events
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Nov 7, 2024
1 parent 5aaf4ca commit e536a40
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
.build()
.toByteArray();

private final byte[] MESSAGES_PERSISTED_EVENT_BYTES = ClientEvent.newBuilder()
.setMessagesPersisted(MessagesPersistedEvent.getDefaultInstance())
.build()
.toByteArray();

@Nullable
private FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubConnection;

Expand Down Expand Up @@ -240,10 +245,29 @@ public CompletionStage<Boolean> handleNewMessageAvailable(final UUID accountIden
}

return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
.thenApply(listeners -> listeners > 0);
}

/**
* Publishes an event notifying a specific device that messages have been persisted from short-term to long-term
* storage.
*
* @param accountIdentifier the account identifier for which messages have been persisted
* @param deviceId the ID of the device within the target account
*
* @return a future that completes when the event has been published
*/
public CompletionStage<Void> handleMessagesPersisted(final UUID accountIdentifier, final byte deviceId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}

return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), MESSAGES_PERSISTED_EVENT_BYTES))
.thenRun(Util.NOOP);
}

/**
* Tests whether a client with the given account/device is connected to this presence manager instance.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;

Expand All @@ -30,6 +31,7 @@ public class MessagePersister implements Managed {
private final MessagesCache messagesCache;
private final MessagesManager messagesManager;
private final AccountsManager accountsManager;
private final PubSubClientEventManager pubSubClientEventManager;

private final Duration persistDelay;

Expand Down Expand Up @@ -63,13 +65,16 @@ public class MessagePersister implements Managed {

public MessagePersister(final MessagesCache messagesCache, final MessagesManager messagesManager,
final AccountsManager accountsManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final Duration persistDelay,
final PubSubClientEventManager pubSubClientEventManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final Duration persistDelay,
final int dedicatedProcessWorkerThreadCount
) {

this.messagesCache = messagesCache;
this.messagesManager = messagesManager;
this.accountsManager = accountsManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.persistDelay = persistDelay;
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];

Expand Down Expand Up @@ -206,6 +211,7 @@ void persistQueue(final Account account, final Device device) throws MessagePers
maybeUnlink(account, deviceId); // may throw, in which case we'll retry later by the usual mechanism
} finally {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
pubSubClientEventManager.handleMessagesPersisted(accountUuid, deviceId);
sample.stop(persistQueueTimer);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ record CommandDependencies(
ReportMessageManager reportMessageManager,
MessagesCache messagesCache,
MessagesManager messagesManager,
PubSubClientEventManager pubSubClientEventManager,
KeysManager keysManager,
APNSender apnSender,
FcmSender fcmSender,
Expand Down Expand Up @@ -271,6 +272,7 @@ static CommandDependencies build(
reportMessageManager,
messagesCache,
messagesManager,
pubSubClientEventManager,
keys,
apnSender,
fcmSender,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ protected void run(Environment environment, Namespace namespace, WhisperServerCo
});
}

final MessagePersister messagePersister = new MessagePersister(deps.messagesCache(), deps.messagesManager(),
deps.accountsManager(), deps.dynamicConfigurationManager(),
final MessagePersister messagePersister = new MessagePersister(deps.messagesCache(),
deps.messagesManager(),
deps.accountsManager(),
deps.pubSubClientEventManager(),
deps.dynamicConfigurationManager(),
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
namespace.getInt(WORKER_COUNT));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
import org.junit.jupiter.api.AfterAll;
Expand Down Expand Up @@ -97,8 +99,8 @@ void handleClientConnected(final boolean displaceRemotely) throws InterruptedExc
final byte deviceId = Device.PRIMARY_ID;

final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);

final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);

localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
Expand Down Expand Up @@ -144,15 +146,12 @@ void handleNewMessageAvailable(final boolean messageAvailableRemotely) throws In
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;

final AtomicBoolean messageReceived = new AtomicBoolean(false);
final CountDownLatch messageReceivedLatch = new CountDownLatch(1);

localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleNewMessageAvailable() {
synchronized (messageReceived) {
messageReceived.set(true);
messageReceived.notifyAll();
}
messageReceivedLatch.countDown();
}
}).toCompletableFuture().join();

Expand All @@ -161,13 +160,32 @@ public void handleNewMessageAvailable() {

assertTrue(messagePresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());

synchronized (messageReceived) {
while (!messageReceived.get()) {
messageReceived.wait();
assertTrue(messageReceivedLatch.await(2, TimeUnit.SECONDS),
"Message not received within time limit");
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void handleMessagesPersisted(final boolean messagesPersistedRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;

final CountDownLatch messagesPersistedLatch = new CountDownLatch(1);

localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleMessagesPersistedPubSub() {
messagesPersistedLatch.countDown();
}
}
}).toCompletableFuture().join();

final PubSubClientEventManager persistingPresenceManager =
messagesPersistedRemotely ? remotePresenceManager : localPresenceManager;

persistingPresenceManager.handleMessagesPersisted(accountIdentifier, deviceId).toCompletableFuture().join();

assertTrue(messageReceived.get());
assertTrue(messagesPersistedLatch.await(2, TimeUnit.SECONDS),
"Message persistence event not received within time limit");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.protobuf.ByteString;
Expand All @@ -32,6 +33,7 @@
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
Expand All @@ -53,6 +55,7 @@ class MessagePersisterIntegrationTest {
private ExecutorService messageDeletionExecutorService;
private MessagesCache messagesCache;
private MessagesManager messagesManager;
private PubSubClientEventManager pubSubClientEventManager;
private MessagePersister messagePersister;
private Account account;

Expand Down Expand Up @@ -82,8 +85,10 @@ void setUp() throws Exception {
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService);
pubSubClientEventManager = mock(PubSubClientEventManager.class);

messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1);
pubSubClientEventManager, dynamicConfigurationManager, PERSIST_DELAY, 1);

account = mock(Account.class);

Expand Down Expand Up @@ -178,6 +183,8 @@ public boolean handleMessagesPersisted() {
.toList();

assertEquals(expectedMessages, persistedMessages);

verify(pubSubClientEventManager).handleMessagesPersisted(account.getUuid(), Device.PRIMARY_ID);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import reactor.core.scheduler.Scheduler;
Expand All @@ -66,6 +67,7 @@ class MessagePersisterTest {
private MessagePersister messagePersister;
private AccountsManager accountsManager;
private MessagesManager messagesManager;
private PubSubClientEventManager pubSubClientEventManager;
private Account destinationAccount;

private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
Expand Down Expand Up @@ -100,7 +102,8 @@ void setUp() throws Exception {
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
pubSubClientEventManager = mock(PubSubClientEventManager.class);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, pubSubClientEventManager,
dynamicConfigurationManager, PERSIST_DELAY, 1);

when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
Expand Down Expand Up @@ -154,6 +157,8 @@ void testPersistNextQueuesSingleQueue() {
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID),
eq(DESTINATION_DEVICE));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());

verify(pubSubClientEventManager).handleMessagesPersisted(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
}

@Test
Expand Down Expand Up @@ -223,6 +228,8 @@ void testPersistQueueRetry() {
assertEquals(List.of(queueName),
messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName),
Instant.now().plus(messagePersister.getPersistDelay()), 1));

verify(pubSubClientEventManager).handleMessagesPersisted(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
}

@Test
Expand All @@ -241,6 +248,8 @@ void testPersistQueueRetryLoop() {
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class,
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)));

verify(pubSubClientEventManager).handleMessagesPersisted(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void setUp() {
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ private TestNotifyIdleDevicesCommand(final MessagesManager messagesManager,
null,
null,
null,
null,
null);

this.idleDeviceNotificationScheduler = idleDeviceNotificationScheduler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public TestStartPushNotificationExperimentCommand(
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
Expand Down

0 comments on commit e536a40

Please sign in to comment.