diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 7a6553fcc..5fcf25655 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -632,7 +632,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro keyspaceNotificationDispatchExecutor); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, - messageDeliveryScheduler, messageDeletionAsyncExecutor, clock); + messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager); ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases, recurringJobExecutor, config.getClientReleaseConfiguration().refreshInterval(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MessageCacheConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MessageCacheConfiguration.java index 82146f1a4..62519ab32 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MessageCacheConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MessageCacheConfiguration.java @@ -27,5 +27,4 @@ public FaultTolerantRedisClusterFactory getRedisClusterConfiguration() { public int getPersistDelayMinutes() { return persistDelayMinutes; } - } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java index 726e9a27c..4e55f47cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java @@ -5,21 +5,9 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; -import java.util.List; - -import javax.validation.constraints.NotNull; - -public record DynamicMessagesConfiguration(@NotNull List dynamoKeySchemes) { - public enum DynamoKeyScheme { - TRADITIONAL, - LAZY_DELETION; - } +public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) { public DynamicMessagesConfiguration() { - this(List.of(DynamoKeyScheme.TRADITIONAL)); - } - - public DynamoKeyScheme writeKeyScheme() { - return dynamoKeySchemes().getLast(); + this(false, false); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 77791c152..bc84b3ffe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -24,7 +24,6 @@ import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.responses.ApiResponse; - import java.security.MessageDigest; import java.time.Clock; import java.time.Duration; @@ -73,8 +72,8 @@ import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; -import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; +import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; @@ -261,7 +260,7 @@ public MessageController( @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) @ManagedAsync - @Operation( + @Operation( summary = "Send a message", description = """ Deliver a message to a single recipient. May be authenticated or unauthenticated; if unauthenticated, @@ -309,9 +308,10 @@ public Response sendMessage(@ReadOnly @Auth Optional source if (groupSendToken != null) { if (!source.isEmpty() || !accessKey.isEmpty()) { - throw new BadRequestException("Group send endorsement tokens should not be combined with other authentication"); + throw new BadRequestException( + "Group send endorsement tokens should not be combined with other authentication"); } else if (isStory) { - throw new BadRequestException("Group send endorsement tokens should not be sent for story messages"); + throw new BadRequestException("Group send endorsement tokens should not be sent for story messages"); } } @@ -346,8 +346,7 @@ public Response sendMessage(@ReadOnly @Auth Optional source } final Optional spamReportToken = switch (senderType) { - case SENDER_TYPE_IDENTIFIED -> - reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination); + case SENDER_TYPE_IDENTIFIED -> reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination); default -> Optional.empty(); }; @@ -470,7 +469,7 @@ public Response sendMessage(@ReadOnly @Auth Optional source throw new WebApplicationException(Response.status(409) .type(MediaType.APPLICATION_JSON_TYPE) .entity(new MismatchedDevices(e.getMissingDevices(), - e.getExtraDevices())) + e.getExtraDevices())) .build()); } catch (StaleDevicesException e) { throw new WebApplicationException(Response.status(410) @@ -621,27 +620,28 @@ public Response sendMultiRecipientMessage( Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); recipients.values().forEach(recipient -> { - final Account account = recipient.account(); - - try { - DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet()); - - DestinationDeviceValidator.validateRegistrationIds( - account, - recipient.deviceIdToRegistrationId().entrySet(), - Map.Entry::getKey, - e -> Integer.valueOf(e.getValue()), - recipient.serviceIdentifier().identityType() == IdentityType.PNI); - } catch (MismatchedDevicesException e) { - accountMismatchedDevices.add( - new AccountMismatchedDevices( - recipient.serviceIdentifier(), - new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); - } catch (StaleDevicesException e) { - accountStaleDevices.add( - new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); - } - }); + final Account account = recipient.account(); + + try { + DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), + Collections.emptySet()); + + DestinationDeviceValidator.validateRegistrationIds( + account, + recipient.deviceIdToRegistrationId().entrySet(), + Map.Entry::getKey, + e -> Integer.valueOf(e.getValue()), + recipient.serviceIdentifier().identityType() == IdentityType.PNI); + } catch (MismatchedDevicesException e) { + accountMismatchedDevices.add( + new AccountMismatchedDevices( + recipient.serviceIdentifier(), + new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); + } catch (StaleDevicesException e) { + accountStaleDevices.add( + new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); + } + }); if (!accountMismatchedDevices.isEmpty()) { return Response .status(409) @@ -667,6 +667,11 @@ public Response sendMultiRecipientMessage( } try { + @Nullable final byte[] sharedMrmKey = + dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData() + ? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage) + : null; + CompletableFuture.allOf( recipients.values().stream() .flatMap(recipientData -> { @@ -692,8 +697,7 @@ public Response sendMultiRecipientMessage( sentMessageCounter.increment(); sendCommonPayloadMessage( destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, - online, - isStory, isUrgent, payload); + online, isStory, isUrgent, payload, sharedMrmKey); }, multiRecipientMessageExecutor)); }) @@ -739,8 +743,8 @@ private void checkAccessKeys( .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) .map(account -> account.getUnidentifiedAccessKey() - .filter(b -> b.length == keyLength) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) + .filter(b -> b.length == keyLength) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) .reduce(new byte[keyLength], (a, b) -> { final byte[] xor = new byte[keyLength]; @@ -828,23 +832,28 @@ public CompletableFuture removePendingMessage(@ReadOnly @Auth Authenti auth.getAuthenticatedDevice(), uuid, null) - .thenAccept(maybeDeletedMessage -> { - maybeDeletedMessage.ifPresent(deletedMessage -> { + .thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> { - WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getServerTimestamp(), - auth.getAuthenticatedDevice()); + WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(), + auth.getAuthenticatedDevice()); - if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { + if (removedMessage.sourceServiceId().isPresent() + && removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) { + if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) { try { - receiptSender.sendReceipt( - ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), - AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); + receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(), + aciServiceIdentifier, removedMessage.clientTimestamp()); } catch (Exception e) { logger.warn("Failed to send delivery receipt", e); } + } else { + // If source service ID is present and the envelope type is not a server delivery receipt, then + // the source service ID *should always* be an ACI -- PNIs are receive-only, so they can only be the + // "source" via server delivery receipts + logger.warn("Source service ID unexpectedly a PNI service ID"); } - }); - }) + } + })) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } @@ -943,19 +952,25 @@ private void sendCommonPayloadMessage(Account destinationAccount, boolean online, boolean story, boolean urgent, - byte[] payload) { + byte[] payload, + @Nullable byte[] sharedMrmKey) { final Envelope.Builder messageBuilder = Envelope.newBuilder(); final long serverTimestamp = System.currentTimeMillis(); messageBuilder .setType(Type.UNIDENTIFIED_SENDER) - .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) + .setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp) .setServerTimestamp(serverTimestamp) - .setContent(ByteString.copyFrom(payload)) .setStory(story) .setUrgent(urgent) - .setDestinationUuid(serviceIdentifier.toServiceIdentifierString()); + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()); + + if (sharedMrmKey != null) { + messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); + } + // mrm views phase 1: always set content + messageBuilder.setContent(ByteString.copyFrom(payload)); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 3e0f8e0f7..37cd1d886 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -31,15 +31,15 @@ public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIden final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder(); envelopeBuilder.setType(envelopeType) - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(System.currentTimeMillis()) - .setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) + .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) .setStory(story) .setUrgent(urgent); if (sourceAccount != null && sourceDeviceId != null) { envelopeBuilder - .setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString()) + .setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString()) .setSourceDevice(sourceDeviceId.intValue()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java index 226d243a1..505988a60 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java @@ -40,15 +40,15 @@ public record OutgoingMessageEntity(UUID guid, public MessageProtos.Envelope toEnvelope() { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() .setType(MessageProtos.Envelope.Type.forNumber(type())) - .setTimestamp(timestamp()) + .setClientTimestamp(timestamp()) .setServerTimestamp(serverTimestamp()) - .setDestinationUuid(destinationUuid().toServiceIdentifierString()) + .setDestinationServiceId(destinationUuid().toServiceIdentifierString()) .setServerGuid(guid().toString()) .setStory(story) .setUrgent(urgent); if (sourceUuid() != null) { - builder.setSourceUuid(sourceUuid().toServiceIdentifierString()); + builder.setSourceServiceId(sourceUuid().toServiceIdentifierString()); builder.setSourceDevice(sourceDevice()); } @@ -72,10 +72,10 @@ public static OutgoingMessageEntity fromEnvelope(final MessageProtos.Envelope en return new OutgoingMessageEntity( UUID.fromString(envelope.getServerGuid()), envelope.getType().getNumber(), - envelope.getTimestamp(), - envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null, + envelope.getClientTimestamp(), + envelope.hasSourceServiceId() ? ServiceIdentifier.valueOf(envelope.getSourceServiceId()) : null, envelope.getSourceDevice(), - envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null, + envelope.hasDestinationServiceId() ? ServiceIdentifier.valueOf(envelope.getDestinationServiceId()) : null, envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.getContent().toByteArray(), envelope.getServerTimestamp(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java index a7b473f24..db8e28662 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java @@ -50,11 +50,11 @@ public void measureAccountOutgoingMessageUuidMismatches(final Account account, public void measureAccountEnvelopeUuidMismatches(final Account account, final MessageProtos.Envelope envelope) { - if (envelope.hasDestinationUuid()) { + if (envelope.hasDestinationServiceId()) { try { - measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid())); + measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationServiceId())); } catch (final IllegalArgumentException ignored) { - logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid()); + logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationServiceId()); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index ecc2ef724..a209aca73 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -92,7 +92,7 @@ public void sendMessage(final Account account, final Device device, final Envelo CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent), URGENT_TAG_NAME, String.valueOf(message.getUrgent()), STORY_TAG_NAME, String.valueOf(message.getStory()), - SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceUuid())) + SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) .increment(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index fff044b84..74f623518 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -45,10 +45,10 @@ public void sendReceipt(ServiceIdentifier sourceIdentifier, byte sourceDeviceId, destinationAccount -> { final Envelope.Builder message = Envelope.newBuilder() .setServerTimestamp(System.currentTimeMillis()) - .setSourceUuid(sourceIdentifier.toServiceIdentifierString()) - .setSourceDevice((int) sourceDeviceId) - .setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) - .setTimestamp(messageId) + .setSourceServiceId(sourceIdentifier.toServiceIdentifierString()) + .setSourceDevice(sourceDeviceId) + .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) + .setClientTimestamp(messageId) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) .setUrgent(false); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index e98832099..de4f5da62 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -138,12 +138,13 @@ void sendMessageToSelf( final long serverTimestamp = System.currentTimeMillis(); final Envelope envelope = Envelope.newBuilder() .setType(Envelope.Type.forNumber(message.type())) - .setTimestamp(serverTimestamp) + .setClientTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp) - .setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) + .setDestinationServiceId( + new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) .setContent(ByteString.copyFrom(contents.get())) - .setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) - .setSourceDevice((int) Device.PRIMARY_ID) + .setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) + .setSourceDevice(Device.PRIMARY_ID) .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) .setUrgent(true) .build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 8a9f088c4..e8b660268 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -8,10 +8,10 @@ import static com.codahale.metrics.MetricRegistry.name; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.ScoredValue; -import io.lettuce.core.ScriptOutputType; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; @@ -20,6 +20,7 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; @@ -38,14 +39,17 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Predicate; -import java.util.stream.Collectors; import javax.annotation.Nullable; import org.reactivestreams.Publisher; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.experiment.Experiment; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.Pair; @@ -57,6 +61,62 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +/** + * Manages short-term storage of messages in Redis. Messages are frequently delivered to their destination and deleted + * shortly after they reach the server, and this cache acts as a low-latency holding area for new messages, reducing + * load on higher-latency, longer-term storage systems. Redis in particular provides keyspace notifications, which act + * as a form of pub-sub notifications to alert listeners when new messages arrive. + *

+ * The following structures are used: + *

+ *
{@code queueKey}
+ *
A sorted set of messages in a device’s queue. A message’s score is its queue-local message ID. See + * Redis.io: The + * Sorted Set data type for background on scores and this data structure.
+ *
{@code queueMetadataKey}
+ *
A hash containing message guids and their queue-local message ID. It also contains a {@code counter} key, which is + * incremented to supply the next message ID. This is used to remove a message by GUID from {@code queueKey} by its + * local messageId.
+ *
{@code sharedMrmKey}
+ *
A hash containing a single multi-recipient message pending delivery. It contains: + *
    + *
  • {@code data} - the serialized SealedSenderMultiRecipientMessage data
  • + *
  • fields with each recipient device's “view” into the payload ({@link SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)}
  • + *
+ * Note: this is shared among all of the message's recipients, and it may be located in any Redis shard. As each recipient’s + * message is delivered, its corresponding view is idempotently removed. When {@code data} is the only remaining + * field, the hash will be deleted. + *
+ *
{@code queueLockKey}
+ *
Used to indicate that a queue is being modified by the {@link MessagePersister} and that {@code get_items} should + * return an empty list.
+ *
{@code queueTotalIndexKey}
+ *
A sorted set of all queues in a shard. A queue’s score is the timestamp of its oldest message, which is used by + * the {@link MessagePersister} to prioritize queues to persist.
+ *
+ *

+ * At a high level, the process is: + *

    + *
  1. Insert: the queue metadata is queried for the next incremented message ID. The message data is inserted into + * the queue at that ID, and the message GUID is inserted in the queue metadata.
  2. + *
  3. Get: a batch of messages are retrieved from the queue, potentially with an after-message-ID offset.
  4. + *
  5. Remove: a set of messages are remove by GUID. For each GUID, the message ID is retrieved from the queue metadata, + * and then that single-value range is removed from the queue.
  6. + *
+ * For multi-recipient messages (sometimes abbreviated “MRM”), there are similar operations on the common data during + * insert, get, and remove. MRM inserts must occur before individual queue inserts, while removal is considered + * best-effort, and uses key expiration as back-stop garbage collection. + *

+ * For atomicity, many operations are implemented as Lua scripts that are executed on the Redis server using + * {@code EVAL}/{@code EVALSHA}. + * + * @see MessagesCacheInsertScript + * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript + * @see MessagesCacheGetItemsScript + * @see MessagesCacheRemoveByGuidScript + * @see MessagesCacheRemoveRecipientViewFromMrmDataScript + * @see MessagesCacheRemoveQueueScript + */ public class MessagesCache extends RedisClusterPubSubAdapter implements Managed { private final FaultTolerantRedisCluster redisCluster; @@ -69,17 +129,22 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp // messageDeletionExecutorService wrapped into a reactor Scheduler private final Scheduler messageDeletionScheduler; - private final ClusterLuaScript insertScript; - private final ClusterLuaScript removeByGuidScript; - private final ClusterLuaScript getItemsScript; - private final ClusterLuaScript removeQueueScript; - private final ClusterLuaScript getQueuesToPersistScript; + private final DynamicConfigurationManager dynamicConfigurationManager; + + private final MessagesCacheInsertScript insertScript; + private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript; + private final MessagesCacheRemoveByGuidScript removeByGuidScript; + private final MessagesCacheGetItemsScript getItemsScript; + private final MessagesCacheRemoveQueueScript removeQueueScript; + private final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript; + private final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript; private final ReentrantLock messageListenersLock = new ReentrantLock(); private final Map messageListenersByQueueName = new HashMap<>(); private final Map queueNamesByMessageListener = new IdentityHashMap<>(); private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert")); + private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload")); private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist")); private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid")); @@ -95,6 +160,9 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd")); private final Counter prunedStaleSubscriptionCounter = Metrics.counter( name(MessagesCache.class, "prunedStaleSubscription")); + private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved")); + private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter( + name(MessagesCache.class, "sharedMrmKeyRemoved")); static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); @@ -102,16 +170,49 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; + private static final String MRM_VIEWS_EXPERIMENT_NAME = "mrmViews"; + @VisibleForTesting static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); private static final String GET_FLUX_NAME = MetricsUtil.name(MessagesCache.class, "get"); private static final int PAGE_SIZE = 100; + private static final int REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY = 8; + private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class); public MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService, - final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock) + final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock, + final DynamicConfigurationManager dynamicConfigurationManager) + throws IOException { + this( + redisCluster, + notificationExecutorService, + messageDeliveryScheduler, + messageDeletionExecutorService, + clock, + dynamicConfigurationManager, + new MessagesCacheInsertScript(redisCluster), + new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster), + new MessagesCacheGetItemsScript(redisCluster), + new MessagesCacheRemoveByGuidScript(redisCluster), + new MessagesCacheRemoveQueueScript(redisCluster), + new MessagesCacheGetQueuesToPersistScript(redisCluster), + new MessagesCacheRemoveRecipientViewFromMrmDataScript(redisCluster) + ); + } + + @VisibleForTesting + MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService, + final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock, + final DynamicConfigurationManager dynamicConfigurationManager, + final MessagesCacheInsertScript insertScript, + final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript, + final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript, + final MessagesCacheRemoveQueueScript removeQueueScript, + final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript, + final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript) throws IOException { this.redisCluster = redisCluster; @@ -123,14 +224,15 @@ public MessagesCache(final FaultTolerantRedisCluster redisCluster, final Executo this.messageDeletionExecutorService = messageDeletionExecutorService; this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion"); - this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); - this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", - ScriptOutputType.MULTI); - this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); - this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", - ScriptOutputType.STATUS); - this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", - ScriptOutputType.MULTI); + this.dynamicConfigurationManager = dynamicConfigurationManager; + + this.insertScript = insertScript; + this.insertMrmScript = insertMrmScript; + this.removeByGuidScript = removeByGuidScript; + this.getItemsScript = getItemsScript; + this.removeQueueScript = removeQueueScript; + this.getQueuesToPersistScript = getQueuesToPersistScript; + this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript; } @Override @@ -164,51 +266,51 @@ private void resubscribeAll() { public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope message) { final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); - return (long) insertTimer.record(() -> - insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), - getMessageQueueMetadataKey(destinationUuid, destinationDevice), - getQueueIndexKey(destinationUuid, destinationDevice)), - List.of(messageWithGuid.toByteArray(), - String.valueOf(message.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), - guid.toString().getBytes(StandardCharsets.UTF_8)))); + return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid)); + } + + public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid, + SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { + final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid); + insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage)); + return sharedMrmKey; } - public CompletableFuture> remove(final UUID destinationUuid, + public CompletableFuture> remove(final UUID destinationUuid, final byte destinationDevice, final UUID messageGuid) { return remove(destinationUuid, destinationDevice, List.of(messageGuid)) - .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0))); + .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst())); } - @SuppressWarnings("unchecked") - public CompletableFuture> remove(final UUID destinationUuid, - final byte destinationDevice, - final List messageGuids) { + public CompletableFuture> remove(final UUID destinationUuid, + final byte destinationDevice, final List messageGuids) { final Timer.Sample sample = Timer.start(); - return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice), - getMessageQueueMetadataKey(destinationUuid, destinationDevice), - getQueueIndexKey(destinationUuid, destinationDevice)), - messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)) - .collect(Collectors.toList())) - .thenApplyAsync(result -> { - List serialized = (List) result; + return removeByGuidScript.execute(destinationUuid, destinationDevice, messageGuids) + .thenApplyAsync(serialized -> { - final List removedMessages = new ArrayList<>(serialized.size()); + final List removedMessages = new ArrayList<>(serialized.size()); + final List sharedMrmKeysToUpdate = new ArrayList<>(); for (final byte[] bytes : serialized) { try { - removedMessages.add(MessageProtos.Envelope.parseFrom(bytes)); + final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes); + removedMessages.add(RemovedMessage.fromEnvelope(envelope)); + if (envelope.hasSharedMrmKey()) { + sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray()); + } } catch (final InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); } } + removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice); return removedMessages; - }, messageDeletionExecutorService) - .whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer)); + }, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer)); + } public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) { @@ -251,7 +353,7 @@ public Publisher get(final UUID destinationUuid, final b private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message, long earliestAllowableTimestamp) { - return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp; + return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp; } private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice, @@ -283,37 +385,101 @@ Flux getAllMessages(final UUID destinationUuid, final by // we want to ensure we don’t accidentally block the Lettuce/netty i/o executors .publishOn(messageDeliveryScheduler) .map(Pair::first) - .flatMapIterable(queueItems -> { - final List envelopes = new ArrayList<>(queueItems.size() / 2); + .concatMap(queueItems -> { + + final List> envelopes = new ArrayList<>(queueItems.size() / 2); for (int i = 0; i < queueItems.size() - 1; i += 2) { try { final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i)); - envelopes.add(message); + final Mono messageMono; + if (message.hasSharedMrmKey()) { + maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice); + + // mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content + messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build()); + } else { + messageMono = Mono.just(message); + } + + envelopes.add(messageMono); + } catch (InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); } } - return envelopes; + return Flux.mergeSequential(envelopes); }); } - private Flux, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, + /** + * Runs the fetch and compare logic for the MRM view experiment, if it is enabled. + * + * @see DynamicMessagesConfiguration#mrmViewExperimentEnabled() + */ + private void maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid, + final byte destinationDevice) { + if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration() + .mrmViewExperimentEnabled()) { + + final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME); + + final byte[] key = mrmMessage.getSharedMrmKey().toByteArray(); + final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey( + new AciServiceIdentifier(destinationUuid), destinationDevice); + + final Mono mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive( + conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey) + .collectList() + .publishOn(messageDeliveryScheduler) + .handle((mrmDataAndView, sink) -> { + try { + assert mrmDataAndView.size() == 2; + + final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient( + mrmDataAndView.getFirst().getValue(), + mrmDataAndView.getLast().getValue()); + + sink.next(mrmMessage.toBuilder() + .clearSharedMrmKey() + .setContent(ByteString.copyFrom(content)) + .build()); + + mrmContentRetrievedCounter.increment(); + } catch (Exception e) { + sink.error(e); + } + }))); + + experiment.compareFutureResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), + mrmMessageMono.toFuture()); + } + } + + /** + * Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure + */ + private void removeRecipientViewFromMrmData(final List sharedMrmKeys, final UUID accountUuid, + final byte deviceId) { + Flux.fromIterable(sharedMrmKeys) + .collectMultimap(SlotHash::getSlot) + .flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values())) + .flatMap( + keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId), + REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY) + .subscribe(sharedMrmDataKeyRemovedCounter::increment, e -> logger.warn("Error removing recipient view", e)); + } + + private Mono, Long>> getNextMessagePage(final UUID destinationUuid, + final byte destinationDevice, long messageId) { - return getItemsScript.executeBinaryReactive( - List.of(getMessageQueueKey(destinationUuid, destinationDevice), - getPersistInProgressKey(destinationUuid, destinationDevice)), - List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8), - String.valueOf(messageId).getBytes(StandardCharsets.UTF_8))) - .map(result -> { + return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId) + .map(queueItems -> { logger.trace("Processing page: {}", messageId); - @SuppressWarnings("unchecked") - List queueItems = (List) result; - if (queueItems.isEmpty()) { return new Pair<>(Collections.emptyList(), null); } @@ -324,7 +490,7 @@ private Flux, Long>> getNextMessagePage(final UUID destination } final long lastMessageId = Long.parseLong( - new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8)); + new String(queueItems.getLast(), StandardCharsets.UTF_8)); return new Pair<>(queueItems, lastMessageId); }); @@ -362,10 +528,35 @@ public CompletableFuture clear(final UUID destinationUuid) { public CompletableFuture clear(final UUID destinationUuid, final byte deviceId) { final Timer.Sample sample = Timer.start(); - return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId), - getMessageQueueMetadataKey(destinationUuid, deviceId), - getQueueIndexKey(destinationUuid, deviceId)), - Collections.emptyList()) + return removeQueueScript.execute(destinationUuid, deviceId, Collections.emptyList()) + .publishOn(messageDeletionScheduler) + .expand(messagesToProcess -> { + if (messagesToProcess.isEmpty()) { + return Mono.empty(); + } + + final List mrmKeys = new ArrayList<>(messagesToProcess.size()); + final List processedMessages = new ArrayList<>(messagesToProcess.size()); + for (byte[] serialized : messagesToProcess) { + try { + final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(serialized); + + processedMessages.add(message.getServerGuid()); + + if (message.hasSharedMrmKey()) { + mrmKeys.add(message.getSharedMrmKey().toByteArray()); + } + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + + removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId); + + return removeQueueScript.execute(destinationUuid, deviceId, processedMessages); + }) + .then() + .toFuture() .thenRun(() -> sample.stop(clearQueueTimer)); } @@ -375,11 +566,7 @@ int getNextSlotToPersist() { } List getQueuesToPersist(final int slot, final Instant maxTime, final int limit) { - //noinspection unchecked - return getQueuesToPersistTimer.record(() -> (List) getQueuesToPersistScript.execute( - List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)), - List.of(String.valueOf(maxTime.toEpochMilli()), - String.valueOf(limit)))); + return getQueuesToPersistTimer.record(() -> getQueuesToPersistScript.execute(slot, maxTime, limit)); } void addQueueToPersist(final UUID accountUuid, final byte deviceId) { @@ -538,29 +725,36 @@ static String getQueueNameFromKeyspaceChannel(final String channel) { return channel.substring(startOfHashTag + 1, endOfHashTag); } - @VisibleForTesting static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) { return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) { + static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) { return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) { + static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) { return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId)); } - private static byte[] getQueueIndexKey(final int slot) { + static byte[] getQueueIndexKey(final int slot) { return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) { + static byte[] getSharedMrmKey(final UUID mrmGuid) { + return ("mrm::{" + mrmGuid.toString() + "}").getBytes(StandardCharsets.UTF_8); + } + + static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) { return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getUnlinkInProgressKey(final UUID accountUuid) { - return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8); + static byte[] getSharedMrmViewKey(final AciServiceIdentifier serviceIdentifier, final byte deviceId) { + final ByteBuffer keyBb = ByteBuffer.allocate(18); + keyBb.put(serviceIdentifier.toFixedWidthByteArray()); + keyBb.put(deviceId); + assert !keyBb.hasRemaining(); + return keyBb.array(); } static UUID getAccountUuidFromQueueName(final String queueName) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java new file mode 100644 index 000000000..58e5a824a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.UUID; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import reactor.core.publisher.Mono; + +/** + * Retrieves a list of messages and their corresponding queue-local IDs for the device. To support streaming processing, + * the last queue-local message ID from a previous call may be used as the {@code afterMessageId}. + */ +class MessagesCacheGetItemsScript { + + private final ClusterLuaScript getItemsScript; + + MessagesCacheGetItemsScript(FaultTolerantRedisCluster redisCluster) throws IOException { + this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT); + } + + Mono> execute(final UUID destinationUuid, final byte destinationDevice, + int limit, long afterMessageId) { + final List keys = List.of( + MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey + MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey + ); + final List args = List.of( + String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit + String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId + ); + //noinspection unchecked + return getItemsScript.executeBinaryReactive(keys, args) + .map(result -> (List) result) + .next(); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetQueuesToPersistScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetQueuesToPersistScript.java new file mode 100644 index 000000000..8f3703f3e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetQueuesToPersistScript.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + +/** + * Returns a list of queues that may be persisted. They will be sorted from oldest to more recent, limited by the + * {@code maxTime} argument. + * + * @see MessagePersister + */ +class MessagesCacheGetQueuesToPersistScript { + + private final ClusterLuaScript getQueuesToPersistScript; + + MessagesCacheGetQueuesToPersistScript(final FaultTolerantRedisCluster redisCluster) throws IOException { + this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", + ScriptOutputType.MULTI); + } + + List execute(final int slot, final Instant maxTime, final int limit) { + final List keys = List.of( + new String(MessagesCache.getQueueIndexKey(slot), StandardCharsets.UTF_8) // queueTotalIndexKey + ); + final List args = List.of( + String.valueOf(maxTime.toEpochMilli()), // maxTime + String.valueOf(limit) // limit + ); + + //noinspection unchecked + return (List) getQueuesToPersistScript.execute(keys, args); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java new file mode 100644 index 000000000..257e90b86 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + +/** + * Inserts an envelope into the message queue for a destination device. + */ +class MessagesCacheInsertScript { + + private final ClusterLuaScript insertScript; + + MessagesCacheInsertScript(FaultTolerantRedisCluster redisCluster) throws IOException { + this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); + } + + long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { + assert envelope.hasServerGuid(); + assert envelope.hasServerTimestamp(); + + final List keys = List.of( + MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey + MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey + MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey + ); + + final List args = new ArrayList<>(Arrays.asList( + envelope.toByteArray(), // message + String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime + envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid + )); + + return (long) insertScript.executeBinary(keys, args); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java new file mode 100644 index 000000000..28e8f0d59 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + +/** + * Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as + * fields in the hash. + * + * @see SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient) + */ +class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript { + + private final ClusterLuaScript script; + + MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster) + throws IOException { + this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua", + ScriptOutputType.INTEGER); + } + + void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) { + final List keys = List.of( + sharedMrmKey // sharedMrmKey + ); + + // Pre-allocate capacity for the most fields we expect -- 6 devices per recipient, plus the data field. + final List args = new ArrayList<>(message.getRecipients().size() * 6 + 1); + args.add(message.serialized()); + + message.getRecipients().forEach((serviceId, recipient) -> { + for (byte device : recipient.getDevices()) { + final byte[] key = new byte[18]; + System.arraycopy(serviceId.toServiceIdFixedWidthBinary(), 0, key, 0, 17); + key[17] = device; + args.add(key); + args.add(message.serializedRecipientView(recipient)); + } + }); + + script.executeBinary(keys, args); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScript.java new file mode 100644 index 000000000..c8847a57e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScript.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + +/** + * Removes a list of message GUIDs from the queue of a destination device. + */ +class MessagesCacheRemoveByGuidScript { + + private final ClusterLuaScript removeByGuidScript; + + MessagesCacheRemoveByGuidScript(final FaultTolerantRedisCluster redisCluster) throws IOException { + this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", + ScriptOutputType.OBJECT); + } + + CompletableFuture> execute(final UUID destinationUuid, final byte destinationDevice, + final List messageGuids) { + + final List keys = List.of( + MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey + MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey + MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey + ); + final List args = messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)) + .toList(); + + //noinspection unchecked + return removeByGuidScript.executeBinaryAsync(keys, args) + .thenApply(result -> (List) result); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScript.java new file mode 100644 index 000000000..2188f0933 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScript.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import reactor.core.publisher.Mono; + +/** + * Removes a device's queue from the cache. For a non-empty queue, this script must be executed multiple times. + *

    + *
  1. The first call will return a list of messages to check for {@code sharedMrmKeys}. If a {@code sharedMrmKey} is present, {@link MessagesCacheRemoveRecipientViewFromMrmDataScript} must be called.
  2. + *
  3. Once theses messages have been processed, this script should be called again, confirming that the messages have been processed.
  4. + *
  5. This should be repeated until the script returns an empty list, as the script only returns a page ({@value PAGE_SIZE}) of messages at a time.
  6. + *
+ */ +class MessagesCacheRemoveQueueScript { + + private static final int PAGE_SIZE = 100; + + private final ClusterLuaScript removeQueueScript; + + MessagesCacheRemoveQueueScript(FaultTolerantRedisCluster redisCluster) throws IOException { + this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", + ScriptOutputType.MULTI); + } + + Mono> execute(final UUID destinationUuid, final byte destinationDevice, + final List processedMessageGuids) { + + final List keys = List.of( + MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey + MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey + MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey + ); + + final List args = new ArrayList<>(); + + args.addFirst(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8)); // limit + args.addAll(processedMessageGuids.stream().map(guid -> guid.getBytes(StandardCharsets.UTF_8)) + .toList()); // processedMessageGuids + + //noinspection unchecked + return removeQueueScript.executeBinaryReactive(keys, args) + .map(result -> (List) result) + .next(); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScript.java new file mode 100644 index 000000000..ada2b651c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScript.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import reactor.core.publisher.Mono; + +/** + * Removes the given destination device from the given {@code sharedMrmKeys}. If there are no devices remaining in the + * hash as a result, the shared payload is deleted. + *

+ * NOTE: Callers are responsible for ensuring that all keys are in the same slot. + */ +class MessagesCacheRemoveRecipientViewFromMrmDataScript { + + private final ClusterLuaScript removeRecipientViewFromMrmDataScript; + + MessagesCacheRemoveRecipientViewFromMrmDataScript(final FaultTolerantRedisCluster redisCluster) throws IOException { + this.removeRecipientViewFromMrmDataScript = ClusterLuaScript.fromResource(redisCluster, + "lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER); + } + + Mono execute(final Collection keysCollection, final AciServiceIdentifier serviceIdentifier, + final byte deviceId) { + final List keys = keysCollection instanceof List + ? (List) keysCollection + : new ArrayList<>(keysCollection); + + return removeRecipientViewFromMrmDataScript.executeBinaryReactive(keys, + List.of(MessagesCache.getSharedMrmViewKey(serviceIdentifier, deviceId))) + .map(o -> (long) o) + .next(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 2c7cc4011..9fb05c7bb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -19,8 +19,10 @@ import java.util.stream.Collectors; import javax.annotation.Nullable; import org.reactivestreams.Publisher; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.Pair; @@ -62,8 +64,8 @@ public void insert(UUID destinationUuid, byte destinationDevice, Envelope messag messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message); - if (message.hasSourceUuid() && !destinationUuid.toString().equals(message.getSourceUuid())) { - reportMessageManager.store(message.getSourceUuid(), messageGuid); + if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) { + reportMessageManager.store(message.getSourceServiceId(), messageGuid); } } @@ -137,7 +139,7 @@ public CompletableFuture clear(UUID destinationUuid, byte deviceId) { return messagesCache.clear(destinationUuid, deviceId); } - public CompletableFuture> delete(UUID destinationUuid, Device destinationDevice, UUID guid, + public CompletableFuture> delete(UUID destinationUuid, Device destinationDevice, UUID guid, @Nullable Long serverTimestamp) { return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid) .thenComposeAsync(removed -> { @@ -146,12 +148,16 @@ public CompletableFuture> delete(UUID destinationUuid, Device return CompletableFuture.completedFuture(removed); } + final CompletableFuture> maybeDeletedEnvelope; if (serverTimestamp == null) { - return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDevice, guid); + maybeDeletedEnvelope = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, + destinationDevice, guid); } else { - return messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, serverTimestamp); + maybeDeletedEnvelope = messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, + serverTimestamp); } + return maybeDeletedEnvelope.thenApply(maybeEnvelope -> maybeEnvelope.map(RemovedMessage::fromEnvelope)); }, messageDeletionExecutor); } @@ -194,4 +200,14 @@ public void removeMessageAvailabilityListener(final MessageAvailabilityListener messagesCache.removeMessageAvailabilityListener(listener); } + /** + * Inserts the shared multi-recipient message payload to storage. + * + * @return a key where the shared data is stored + * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript + */ + public byte[] insertSharedMultiRecipientMessagePayload( + SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { + return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RemovedMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RemovedMessage.java new file mode 100644 index 000000000..b37dc8248 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RemovedMessage.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import java.util.Optional; +import java.util.UUID; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; + +public record RemovedMessage(Optional sourceServiceId, ServiceIdentifier destinationServiceId, + @VisibleForTesting UUID serverGuid, long serverTimestamp, long clientTimestamp, + MessageProtos.Envelope.Type envelopeType) { + + public static RemovedMessage fromEnvelope(MessageProtos.Envelope envelope) { + return new RemovedMessage( + envelope.hasSourceServiceId() + ? Optional.of(ServiceIdentifier.valueOf(envelope.getSourceServiceId())) + : Optional.empty(), + ServiceIdentifier.valueOf(envelope.getDestinationServiceId()), + UUID.fromString(envelope.getServerGuid()), + envelope.getServerTimestamp(), + envelope.getClientTimestamp(), + envelope.getType() + ); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 2334f24d9..e21ace10d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -294,16 +294,16 @@ public static void recordMessageDeliveryDuration(long timestamp, Device messageD } private void sendDeliveryReceiptFor(Envelope message) { - if (!message.hasSourceUuid()) { + if (!message.hasSourceServiceId()) { return; } try { - receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()), - auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()), - message.getTimestamp()); + receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()), + auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()), + message.getClientTimestamp()); } catch (IllegalArgumentException e) { - logger.error("Could not parse UUID: {}", message.getSourceUuid()); + logger.error("Could not parse UUID: {}", message.getSourceServiceId()); } catch (Exception e) { logger.warn("Failed to send receipt", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index d6fdcdddc..71e113d28 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -205,7 +205,7 @@ static CommandDependencies build( ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor, keyspaceNotificationDispatchExecutor); MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, - messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); + messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, configuration.getDynamoDbTables().getReportMessage().getTableName(), diff --git a/service/src/main/proto/TextSecure.proto b/service/src/main/proto/TextSecure.proto index ef340fb1b..5b7df2ff2 100644 --- a/service/src/main/proto/TextSecure.proto +++ b/service/src/main/proto/TextSecure.proto @@ -11,30 +11,31 @@ option java_outer_classname = "MessageProtos"; message Envelope { enum Type { - UNKNOWN = 0; - CIPHERTEXT = 1; - KEY_EXCHANGE = 2; - PREKEY_BUNDLE = 3; + UNKNOWN = 0; + CIPHERTEXT = 1; + KEY_EXCHANGE = 2; + PREKEY_BUNDLE = 3; SERVER_DELIVERY_RECEIPT = 5; UNIDENTIFIED_SENDER = 6; reserved 7; PLAINTEXT_CONTENT = 8; // for decryption error receipts } - optional Type type = 1; - optional string source_uuid = 11; + optional Type type = 1; + optional string source_service_id = 11; optional uint32 source_device = 7; - optional uint64 timestamp = 5; - optional bytes content = 8; // Contains an encrypted Content + optional uint64 client_timestamp = 5; + optional bytes content = 8; // Contains an encrypted Content optional string server_guid = 9; optional uint64 server_timestamp = 10; optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline - optional string destination_uuid = 13; + optional string destination_service_id = 13; optional bool urgent = 14 [default=true]; optional string updated_pni = 15; optional bool story = 16; // indicates that the content is a story. optional bytes report_spam_token = 17; // token sent when reporting spam - // next: 18 + optional bytes shared_mrm_key = 18; // indicates content should be fetched from multi-recipient message datastore + // next: 19 } message ProvisioningUuid { @@ -42,25 +43,25 @@ message ProvisioningUuid { } message ServerCertificate { - message Certificate { - optional uint32 id = 1; - optional bytes key = 2; - } + message Certificate { + optional uint32 id = 1; + optional bytes key = 2; + } - optional bytes certificate = 1; - optional bytes signature = 2; + optional bytes certificate = 1; + optional bytes signature = 2; } message SenderCertificate { - message Certificate { - optional string sender = 1; - optional string sender_uuid = 6; - optional uint32 sender_device = 2; - optional fixed64 expires = 3; - optional bytes identity_key = 4; - optional ServerCertificate signer = 5; - } + message Certificate { + optional string sender = 1; + optional string sender_uuid = 6; + optional uint32 sender_device = 2; + optional fixed64 expires = 3; + optional bytes identity_key = 4; + optional ServerCertificate signer = 5; + } - optional bytes certificate = 1; - optional bytes signature = 2; + optional bytes certificate = 1; + optional bytes signature = 2; } diff --git a/service/src/main/resources/lua/apn/get.lua b/service/src/main/resources/lua/apn/get.lua index 404a21b19..45d984cec 100644 --- a/service/src/main/resources/lua/apn/get.lua +++ b/service/src/main/resources/lua/apn/get.lua @@ -46,7 +46,7 @@ local getNextInterval = function(interval) end -local results = redis.call("ZRANGEBYSCORE", pendingNotificationQueue, 0, maxTime, "LIMIT", 0, limit) +local results = redis.call("ZRANGE", pendingNotificationQueue, 0, maxTime, "BYSCORE", "LIMIT", 0, limit) local collated = {} if results and next(results) then diff --git a/service/src/main/resources/lua/get_items.lua b/service/src/main/resources/lua/get_items.lua index 045a804ac..cb930b53b 100644 --- a/service/src/main/resources/lua/get_items.lua +++ b/service/src/main/resources/lua/get_items.lua @@ -1,7 +1,10 @@ -local queueKey = KEYS[1] -local queueLockKey = KEYS[2] -local limit = ARGV[1] -local afterMessageId = ARGV[2] +-- gets messages from a device's queue, up to a given limit +-- returns a list of all envelopes and their queue-local IDs + +local queueKey = KEYS[1] -- sorted set of all Envelopes for a device, scored by queue-local ID +local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persistent and must not be read +local limit = ARGV[1] -- [number] the maximum number of messages to return +local afterMessageId = ARGV[2] -- [number] a queue-local ID to exclusively start after, to support pagination. Use -1 to start at the beginning local locked = redis.call("GET", queueLockKey) @@ -9,17 +12,8 @@ if locked then return {} end -if afterMessageId == "null" then - -- An index range is inclusive - local min = 0 - local max = limit - 1 - - if max < 0 then - return {} - end - - return redis.call("ZRANGE", queueKey, min, max, "WITHSCORES") -else - -- note: this is deprecated in Redis 6.2, and should be migrated to zrange after the cluster is updated - return redis.call("ZRANGEBYSCORE", queueKey, "("..afterMessageId, "+inf", "WITHSCORES", "LIMIT", 0, limit) +if afterMessageId == "null" or afterMessageId == nil then + return redis.error_reply("ERR afterMessageId is required") end + +return redis.call("ZRANGE", queueKey, "("..afterMessageId, "+inf", "BYSCORE", "LIMIT", 0, limit, "WITHSCORES") diff --git a/service/src/main/resources/lua/get_queues_to_persist.lua b/service/src/main/resources/lua/get_queues_to_persist.lua index 97f9793ab..6b321710f 100644 --- a/service/src/main/resources/lua/get_queues_to_persist.lua +++ b/service/src/main/resources/lua/get_queues_to_persist.lua @@ -1,8 +1,10 @@ -local queueTotalIndexKey = KEYS[1] -local maxTime = ARGV[1] -local limit = ARGV[2] +-- returns a list of queues that meet persistence criteria -local results = redis.call("ZRANGEBYSCORE", queueTotalIndexKey, 0, maxTime, "LIMIT", 0, limit) +local queueTotalIndexKey = KEYS[1] -- sorted set of all queues in the shard, by timestamp of oldest message +local maxTime = ARGV[1] -- [number] the most recent queue timestamp that may be fetched +local limit = ARGV[2] -- [number] the maximum number of queues to fetch + +local results = redis.call("ZRANGE", queueTotalIndexKey, 0, maxTime, "BYSCORE", "LIMIT", 0, limit) if results and next(results) then redis.call("ZREM", queueTotalIndexKey, unpack(results)) diff --git a/service/src/main/resources/lua/insert_item.lua b/service/src/main/resources/lua/insert_item.lua index 76eea91bf..494dcf305 100644 --- a/service/src/main/resources/lua/insert_item.lua +++ b/service/src/main/resources/lua/insert_item.lua @@ -1,9 +1,12 @@ -local queueKey = KEYS[1] -local queueMetadataKey = KEYS[2] -local queueTotalIndexKey = KEYS[3] -local message = ARGV[1] -local currentTime = ARGV[2] -local guid = ARGV[3] +-- inserts a message into a device's queue, and updates relevant associated data +-- returns a number, the queue-local message ID (useful for testing) + +local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID +local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs +local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message +local message = ARGV[1] -- [bytes] the Envelope to insert +local currentTime = ARGV[2] -- [number] the message timestamp, to sort the queue in the queueTotalIndex +local guid = ARGV[3] -- [string] the message GUID if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then return tonumber(redis.call("HGET", queueMetadataKey, guid)) @@ -14,9 +17,8 @@ local messageId = redis.call("HINCRBY", queueMetadataKey, "counter", 1) redis.call("ZADD", queueKey, "NX", messageId, message) redis.call("HSET", queueMetadataKey, guid, messageId) - -redis.call("EXPIRE", queueKey, 7776000) -- 90 days -redis.call("EXPIRE", queueMetadataKey, 7776000) -- 90 days +redis.call("EXPIRE", queueKey, 2678400) -- 31 days +redis.call("EXPIRE", queueMetadataKey, 2678400) -- 31 days redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey) return messageId diff --git a/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua b/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua new file mode 100644 index 000000000..214bce4b4 --- /dev/null +++ b/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua @@ -0,0 +1,13 @@ +-- inserts shared multi-recipient message data + +local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data +local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data +-- the remainder of ARGV is list of recipient keys and view data + +redis.call("HSET", sharedMrmKey, "data", mrmData); +redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days + +-- unpack() fails with "too many results" at very large table sizes, so we loop +for i = 2, #ARGV, 2 do + redis.call("HSET", sharedMrmKey, ARGV[i], ARGV[i + 1]) +end diff --git a/service/src/main/resources/lua/remove_item_by_guid.lua b/service/src/main/resources/lua/remove_item_by_guid.lua index b80d32f96..3a588d756 100644 --- a/service/src/main/resources/lua/remove_item_by_guid.lua +++ b/service/src/main/resources/lua/remove_item_by_guid.lua @@ -1,20 +1,26 @@ -local queueKey = KEYS[1] -local queueMetadataKey = KEYS[2] -local queueTotalIndexKey = KEYS[3] +-- removes a list of messages by ID from the cluster, returning the deleted messages +-- returns a list of removed envelopes +-- Note: content may be absent for MRM messages, and for these messages, the caller must update the sharedMrmKey +-- to remove the recipient's reference + +local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID +local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs +local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message +local messageGuids = ARGV -- [list[string]] message GUIDs local removedMessages = {} -for _, guid in ipairs(ARGV) do +for _, guid in ipairs(messageGuids) do local messageId = redis.call("HGET", queueMetadataKey, guid) if messageId then - local envelope = redis.call("ZRANGEBYSCORE", queueKey, messageId, messageId, "LIMIT", 0, 1) + local envelope = redis.call("ZRANGE", queueKey, messageId, messageId, "BYSCORE", "LIMIT", 0, 1) redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId) redis.call("HDEL", queueMetadataKey, guid) if envelope and next(envelope) then - removedMessages[#removedMessages + 1] = envelope[1] + table.insert(removedMessages, envelope[1]) end end end diff --git a/service/src/main/resources/lua/remove_queue.lua b/service/src/main/resources/lua/remove_queue.lua index ace767eb5..68a7f6561 100644 --- a/service/src/main/resources/lua/remove_queue.lua +++ b/service/src/main/resources/lua/remove_queue.lua @@ -1,7 +1,29 @@ -local queueKey = KEYS[1] -local queueMetadataKey = KEYS[2] -local queueTotalIndexKey = KEYS[3] +-- incrementally removes a given device's queue and associated data +-- returns: a page of messages and scores. +-- The messages must be checked for mrmKeys to update. After updating MRM keys, this script must be called again +-- with processedMessageGuids. If the returned table is empty, then +-- the queue has been fully deleted. -redis.call("DEL", queueKey) -redis.call("DEL", queueMetadataKey) -redis.call("ZREM", queueTotalIndexKey, queueKey) +local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID +local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs +local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message +local limit = ARGV[1] -- the maximum number of messages to return +local processedMessageGuids = { unpack(ARGV, 2) } + +for _, guid in ipairs(processedMessageGuids) do + local messageId = redis.call("HGET", queueMetadataKey, guid) + if messageId then + redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId) + redis.call("HDEL", queueMetadataKey, guid) + end +end + +local messages = redis.call("ZRANGE", queueKey, 0, limit-1) + +if #messages == 0 then + redis.call("DEL", queueKey) + redis.call("DEL", queueMetadataKey) + redis.call("ZREM", queueTotalIndexKey, queueKey) +end + +return messages diff --git a/service/src/main/resources/lua/remove_recipient_view_from_mrm_data.lua b/service/src/main/resources/lua/remove_recipient_view_from_mrm_data.lua new file mode 100644 index 000000000..bf91f82a1 --- /dev/null +++ b/service/src/main/resources/lua/remove_recipient_view_from_mrm_data.lua @@ -0,0 +1,17 @@ +-- Removes the given recipient view from the shared MRM data. If the only field remaining after the removal is the +-- `data` field, then the key will be deleted + +local sharedMrmKeys = KEYS -- KEYS: list of all keys in a single slot to update +local recipientViewToRemove = ARGV[1] -- the recipient view to remove from the hash + +local keysDeleted = 0 + +for _, sharedMrmKey in ipairs(sharedMrmKeys) do + redis.call("HDEL", sharedMrmKey, recipientViewToRemove) + if redis.call("HLEN", sharedMrmKey) == 1 then + redis.call("DEL", sharedMrmKey) + keysDeleted = keysDeleted + 1 + end +end + +return keysDeleted diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index e4f0ce911..fc9b3f92d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -87,6 +87,7 @@ import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.IncomingMessage; @@ -121,6 +122,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; +import org.whispersystems.textsecuregcm.storage.RemovedMessage; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @@ -251,6 +253,7 @@ void setup() { final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration); + when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true)); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); @@ -311,7 +314,7 @@ void testSingleDeviceCurrent() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); - assertTrue(captor.getValue().hasSourceUuid()); + assertTrue(captor.getValue().hasSourceServiceId()); assertTrue(captor.getValue().hasSourceDevice()); assertTrue(captor.getValue().getUrgent()); } @@ -353,7 +356,7 @@ void testSingleDeviceCurrentNotUrgent() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); - assertTrue(captor.getValue().hasSourceUuid()); + assertTrue(captor.getValue().hasSourceServiceId()); assertTrue(captor.getValue().hasSourceDevice()); assertFalse(captor.getValue().getUrgent()); } @@ -375,7 +378,7 @@ void testSingleDeviceCurrentByPni() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); - assertTrue(captor.getValue().hasSourceUuid()); + assertTrue(captor.getValue().hasSourceServiceId()); assertTrue(captor.getValue().hasSourceDevice()); } } @@ -410,7 +413,7 @@ void testSingleDeviceCurrentUnidentified() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); - assertFalse(captor.getValue().hasSourceUuid()); + assertFalse(captor.getValue().hasSourceServiceId()); assertFalse(captor.getValue().hasSourceDevice()); } } @@ -444,7 +447,7 @@ void testSingleDeviceCurrentGroupSendEndorsement( assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); if (expectedResponse == 200) { verify(messageSender).sendMessage( - any(Account.class), any(Device.class), argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()), + any(Account.class), any(Device.class), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()), eq(false)); } else { verifyNoMoreInteractions(messageSender); @@ -732,23 +735,27 @@ void testGetMessagesBadAuth() { @Test void testDeleteMessages() { - long timestamp = System.currentTimeMillis(); + long clientTimestamp = System.currentTimeMillis(); UUID sourceUuid = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID(); + final long serverTimestamp = 0; when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null)) .thenReturn( - CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE, - timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)))); + CompletableFutureTestUtil.almostCompletedFuture(Optional.of( + new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)), + new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid1, serverTimestamp, clientTimestamp, + Envelope.Type.CIPHERTEXT)))); UUID uuid2 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null)) .thenReturn( - CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope( - uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, - System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0)))); + CompletableFutureTestUtil.almostCompletedFuture(Optional.of( + new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)), + new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid2, serverTimestamp, clientTimestamp, + Envelope.Type.SERVER_DELIVERY_RECEIPT)))); UUID uuid3 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null)) @@ -766,7 +773,7 @@ void testDeleteMessages() { assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1), - eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp)); + eq(new AciServiceIdentifier(sourceUuid)), eq(clientTimestamp)); } try (final Response response = resources.getJerseyTest() @@ -1068,9 +1075,16 @@ private static Stream testValidateEnvelopeType() { } private record Recipient(ServiceIdentifier uuid, - byte deviceId, - int registrationId, - byte[] perRecipientKeyMaterial) { + Byte[] deviceId, + Integer[] registrationId, + byte[] perRecipientKeyMaterial) { + + Recipient(ServiceIdentifier uuid, + byte deviceId, + int registrationId, + byte[] perRecipientKeyMaterial) { + this(uuid, new Byte[]{deviceId}, new Integer[]{registrationId}, perRecipientKeyMaterial); + } } private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, @@ -1081,8 +1095,13 @@ private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipi bb.put(UUIDUtil.toBytes(r.uuid().uuid())); } - bb.put(r.deviceId()); // device id (1 byte) - bb.putShort((short) r.registrationId()); // registration id (2 bytes) + assert (r.deviceId.length == r.registrationId.length); + + for (int i = 0; i < r.deviceId.length; i++) { + final int hasMore = i == r.deviceId.length - 1 ? 0 : 0x8000; + bb.put(r.deviceId()[i]); // device id (1 byte) + bb.putShort((short) (r.registrationId()[i] | hasMore)); // registration id (2 bytes) + } bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) } @@ -1157,7 +1176,7 @@ void testManyRecipientMessage() throws Exception { .queryParam("story", true) .queryParam("urgent", false) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .put(entity)) { assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200))); @@ -1206,7 +1225,7 @@ private void testMultiRecipientMessage( .queryParam("story", isStory) .queryParam("urgent", urgent) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader) .put(entity)) { @@ -1216,7 +1235,7 @@ private void testMultiRecipientMessage( .sendMessage( any(), any(), - argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()), + argThat(env -> env.getUrgent() == urgent && !env.hasSourceServiceId() && !env.hasSourceDevice()), eq(true)); if (expectedStatus == 200) { SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); @@ -1384,7 +1403,7 @@ void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception { .queryParam("story", false) .queryParam("urgent", false) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { @@ -1395,7 +1414,7 @@ void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception { .sendMessage( any(), any(), - argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()), + argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()), eq(true)); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); assertThat(smrmr.uuids404(), is(empty())); @@ -1423,7 +1442,7 @@ void testMultiRecipientMessageWithInvalidGroupSendEndorsements() throws Exceptio .queryParam("story", false) .queryParam("urgent", false) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { @@ -1454,7 +1473,7 @@ void testMultiRecipientMessageWithExpiredGroupSendEndorsements() throws Exceptio .queryParam("story", false) .queryParam("urgent", false) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { @@ -1620,7 +1639,7 @@ void sendMultiRecipientMessageMismatchedDevices() throws JsonProcessingException .queryParam("story", false) .queryParam("urgent", true) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); // make the PUT request @@ -1663,7 +1682,7 @@ void sendMultiRecipientMessageStaleDevices() throws JsonProcessingException { .queryParam("story", false) .queryParam("urgent", true) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); // make the PUT request @@ -1702,7 +1721,7 @@ void sendMultiRecipientMessageStoryRateLimited() { .queryParam("story", true) .queryParam("urgent", true) .request() - .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HttpHeaders.USER_AGENT, "test") .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); when(rateLimiter.validateAsync(any(UUID.class))) @@ -1730,14 +1749,14 @@ private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UU final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() .setType(MessageProtos.Envelope.Type.forNumber(type)) - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(serverTimestamp) - .setDestinationUuid(destinationUuid.toString()) + .setDestinationServiceId(destinationUuid.toString()) .setStory(story) .setServerGuid(guid.toString()); if (sourceUuid != null) { - builder.setSourceUuid(sourceUuid.toString()); + builder.setSourceServiceId(sourceUuid.toString()); builder.setSourceDevice(sourceDevice); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java index 308d0f74d..eeec5de64 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java @@ -104,7 +104,7 @@ private MessageProtos.Envelope createEnvelope(ServiceIdentifier destinationIdent final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); if (destinationIdentifier != null) { - builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString()); + builder.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()); } return builder.build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 570a9cc9b..19cda192f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -151,7 +151,7 @@ void testSendMessageNoChannel() { private MessageProtos.Envelope generateRandomMessage() { return MessageProtos.Envelope.newBuilder() - .setTimestamp(System.currentTimeMillis()) + .setClientTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis()) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java index 268435bb3..27b2764a0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java @@ -54,8 +54,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; -import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 659ec423d..3a94f104b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -160,8 +160,8 @@ void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); - assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); - assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); + assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); } @@ -208,8 +208,8 @@ void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); - assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); - assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); + assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); } @@ -254,8 +254,8 @@ void changeNumberSameNumberSetPrimaryDevicePrekeyAndSendMessages() throws Except final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); - assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); - assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); + assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); } @@ -296,8 +296,8 @@ void updatePniKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); - assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); - assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); + assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); } @@ -340,8 +340,8 @@ void updatePniKeysSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); - assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); - assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); + assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 81e5e1a7e..a451ce4bf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -81,7 +81,7 @@ void setUp() throws Exception { notificationExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService, - messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC()); + messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), messageDeletionExecutorService); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, @@ -185,12 +185,12 @@ public boolean handleMessagesPersisted() { private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long serverTimestamp) { return MessageProtos.Envelope.newBuilder() - .setTimestamp(serverTimestamp * 2) // client timestamp may not be accurate + .setClientTimestamp(serverTimestamp * 2) // client timestamp may not be accurate .setServerTimestamp(serverTimestamp) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setServerGuid(messageGuid.toString()) - .setDestinationUuid(UUID.randomUUID().toString()) + .setDestinationServiceId(UUID.randomUUID().toString()) .build(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 211c7a6b9..4ae32396e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -40,6 +40,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; @@ -48,12 +49,11 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; - -import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; +@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class MessagePersisterTest { @RegisterExtension @@ -104,7 +104,7 @@ void setUp() throws Exception { resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, - messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); + messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager, keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1); @@ -356,7 +356,8 @@ private void insertMessages(final UUID accountUuid, final byte deviceId, final i final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() - .setTimestamp(firstMessageTimestamp.toEpochMilli() + i) + .setDestinationServiceId(accountUuid.toString()) + .setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i) .setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java new file mode 100644 index 000000000..f0fe7d7f5 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.ScriptOutputType; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; + +class MessagesCacheGetItemsScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @Test + void testCacheGetItemsScript() throws Exception { + final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + final String serverGuid = UUID.randomUUID().toString(); + final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder() + .setServerTimestamp(Instant.now().getEpochSecond()) + .setServerGuid(serverGuid) + .build(); + + insertScript.execute(destinationUuid, deviceId, envelope1); + + final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final List messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1) + .block(Duration.ofSeconds(1)); + + assertNotNull(messageAndScores); + assertEquals(2, messageAndScores.size()); + final MessageProtos.Envelope resultEnvelope = MessageProtos.Envelope.parseFrom( + messageAndScores.getFirst()); + + assertEquals(serverGuid, resultEnvelope.getServerGuid()); + } + + @Test + void testCacheGetItemsInvalidParameter() throws Exception { + final ClusterLuaScript getItemsScript = ClusterLuaScript.fromResource(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + "lua/get_items.lua", ScriptOutputType.OBJECT); + + final byte[] fakeKey = new byte[]{1}; + + final Exception e = assertThrows(RedisCommandExecutionException.class, + () -> getItemsScript.executeBinaryReactive(List.of(fakeKey, fakeKey), + List.of("1".getBytes(StandardCharsets.UTF_8))) + .next() + .block(Duration.ofSeconds(1))); + + assertEquals("ERR afterMessageId is required", e.getMessage()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java new file mode 100644 index 000000000..febe8c32e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.time.Instant; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; + +class MessagesCacheInsertScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @Test + void testCacheInsertScript() throws Exception { + final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder() + .setServerTimestamp(Instant.now().getEpochSecond()) + .setServerGuid(UUID.randomUUID().toString()) + .build(); + + assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1)); + + final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder() + .setServerTimestamp(Instant.now().getEpochSecond()) + .setServerGuid(UUID.randomUUID().toString()) + .build(); + assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2)); + + assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1), + "Repeated with same guid should have same message ID"); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java new file mode 100644 index 000000000..e1ceea81d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.Pair; + +class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @ParameterizedTest + @MethodSource + void testInsert(final int count, final Map> destinations) throws Exception { + + final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); + insertMrmScript.execute(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(destinations)); + + final int totalDevices = destinations.values().stream().mapToInt(List::size).sum(); + final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey)); + assertEquals(totalDevices + 1, hashFieldCount); + } + + public static List testInsert() { + final Map> singleAccount = Map.of( + new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2)); + + final List testCases = new ArrayList<>(); + testCases.add(Arguments.of(1, singleAccount)); + + for (int j = 1000; j <= 30000; j += 1000) { + + final Map> deviceLists = new HashMap<>(); + final Map> manyAccounts = IntStream.range(0, j) + .mapToObj(i -> { + final int deviceCount = 1 + i % 5; + final List devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count) + .mapToObj(v -> (byte) v) + .toList()); + + return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices); + }) + .collect(Collectors.toMap(Pair::first, Pair::second)); + + testCases.add(Arguments.of(j, manyAccounts)); + } + + return testCases; + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java new file mode 100644 index 000000000..52f3b239f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.time.Instant; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; + +class MessagesCacheRemoveByGuidScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @Test + void testCacheRemoveByGuid() throws Exception { + final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + final UUID serverGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder() + .setServerTimestamp(Instant.now().getEpochSecond()) + .setServerGuid(serverGuid.toString()) + .build(); + + insertScript.execute(destinationUuid, deviceId, envelope1); + + final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final List removedMessages = removeByGuidScript.execute(destinationUuid, deviceId, + List.of(serverGuid)).get(1, TimeUnit.SECONDS); + + assertEquals(1, removedMessages.size()); + + final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom( + removedMessages.getFirst()); + + assertEquals(serverGuid, UUID.fromString(resultMessage.getServerGuid())); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java new file mode 100644 index 000000000..b25b3456a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; + +class MessagesCacheRemoveQueueScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + + @Test + void testCacheRemoveQueueScript() throws Exception { + final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder() + .setServerTimestamp(Instant.now().getEpochSecond()) + .setServerGuid(UUID.randomUUID().toString()) + .build(); + + insertScript.execute(destinationUuid, deviceId, envelope1); + + final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final List messagesToCheckForMrmKeys = removeScript.execute(destinationUuid, deviceId, + Collections.emptyList()) + .block(Duration.ofSeconds(1)); + + assertEquals(1, messagesToCheckForMrmKeys.size()); + + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java new file mode 100644 index 000000000..8f623abdd --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.lettuce.core.cluster.SlotHash; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.Pair; +import reactor.core.publisher.Flux; +import reactor.util.function.Tuples; + +class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @ParameterizedTest + @MethodSource + void testUpdateSingleKey(final Map> destinations) throws Exception { + + final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); + insertMrmScript.execute(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(destinations)); + + final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet()) + .flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId)))) + .flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey), + aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2())) + .reduce(Long::sum) + .block(Duration.ofSeconds(35))); + + assertEquals(1, keysRemoved); + + final long keyExists = REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().exists(sharedMrmKey)); + assertEquals(0, keyExists); + } + + public static List>> testUpdateSingleKey() { + final Map> singleAccount = Map.of( + new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2)); + + final List>> testCases = new ArrayList<>(); + testCases.add(singleAccount); + + // Generate a more, from smallish to very large + for (int j = 1000; j <= 81000; j *= 3) { + + final Map> deviceLists = new HashMap<>(); + final Map> manyAccounts = IntStream.range(0, j) + .mapToObj(i -> { + final int deviceCount = 1 + i % 5; + final List devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count) + .mapToObj(v -> (byte) v) + .toList()); + + return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices); + }) + .collect(Collectors.toMap(Pair::first, Pair::second)); + + testCases.add(manyAccounts); + } + + return testCases; + } + + @ParameterizedTest + @ValueSource(ints = {1, 10, 100, 1000, 10000}) + void testUpdateManyKeys(int keyCount) throws Exception { + + final List sharedMrmKeys = new ArrayList<>(keyCount); + final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final byte deviceId = 1; + + for (int i = 0; i < keyCount; i++) { + + final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); + insertMrmScript.execute(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId)); + + sharedMrmKeys.add(sharedMrmKey); + } + + final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys) + .collectMultimap(SlotHash::getSlot) + .flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values())) + .flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId)) + .reduce(Long::sum) + .block(Duration.ofSeconds(5))); + + assertEquals(sharedMrmKeys.size(), keysRemoved); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 9f04dac46..cf9dbb10f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; @@ -25,6 +26,8 @@ import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; import io.lettuce.core.protocol.AsyncCommand; import io.lettuce.core.protocol.RedisCommand; +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; @@ -32,9 +35,12 @@ import java.time.ZoneId; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Deque; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Random; import java.util.UUID; @@ -42,11 +48,13 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.stream.Collectors; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -57,7 +65,12 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.signal.libsignal.protocol.ServiceId; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; @@ -83,6 +96,8 @@ class WithRealCluster { private Scheduler messageDeliveryScheduler; private MessagesCache messagesCache; + private DynamicConfigurationManager dynamicConfigurationManager; + private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final byte DESTINATION_DEVICE_ID = 7; @@ -95,11 +110,16 @@ void setUp() throws Exception { connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); }); + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); + when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true)); + dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + sharedExecutorService = Executors.newSingleThreadExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, - messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); + messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messagesCache.start(); } @@ -148,10 +168,10 @@ void testRemoveByUUID(final boolean sealedSender) throws Exception { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, + final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS); - assertEquals(Optional.of(message), maybeRemovedMessage); + assertEquals(Optional.of(RemovedMessage.fromEnvelope(message)), maybeRemovedMessage); } @ParameterizedTest @@ -181,11 +201,11 @@ void testRemoveBatchByUUID(final boolean sealedSender) throws Exception { message); } - final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) .collect(Collectors.toList())).get(5, TimeUnit.SECONDS); - assertEquals(messagesToRemove, removedMessages); + assertEquals(messagesToRemove.stream().map(RemovedMessage::fromEnvelope).toList(), removedMessages); assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } @@ -283,7 +303,8 @@ void testGetMessagesPublisher(final boolean expectStale) throws Exception { } final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), - sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock); + sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock, + dynamicConfigurationManager); final List actualMessages = Flux.from( messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID)) @@ -320,7 +341,7 @@ void testGetMessagesPublisher(final boolean expectStale) throws Exception { @ParameterizedTest @ValueSource(booleans = {true, false}) void testClearQueueForDevice(final boolean sealedSender) { - final int messageCount = 100; + final int messageCount = 1000; for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { @@ -340,7 +361,7 @@ void testClearQueueForDevice(final boolean sealedSender) { @ParameterizedTest @ValueSource(booleans = {true, false}) void testClearQueueForAccount(final boolean sealedSender) { - final int messageCount = 100; + final int messageCount = 1000; for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { @@ -542,6 +563,57 @@ void testAvailabilityListenerResponses() { }); } + @Test + void testMultiRecipientMessage() throws Exception { + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + + final UUID mrmGuid = UUID.randomUUID(); + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( + new AciServiceIdentifier(destinationUuid), deviceId); + final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm); + + final UUID guid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(guid, true) + .toBuilder() + // clear some things added by the helper + .clearServerGuid() + // mrm views phase 1: messages have content + .setContent( + ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid))))) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) + .build(); + messagesCache.insert(guid, destinationUuid, deviceId, message); + + assertEquals(1L, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)))); + + final List messages = get(destinationUuid, deviceId, 1); + assertEquals(1, messages.size()); + assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid())); + assertFalse(messages.getFirst().hasSharedMrmKey()); + + final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients() + .get(new ServiceId.Aci(destinationUuid)); + assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray()); + + final Optional removedMessage = messagesCache.remove(destinationUuid, deviceId, guid) + .join(); + + assertTrue(removedMessage.isPresent()); + assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString())); + assertTrue(get(destinationUuid, deviceId, 1).isEmpty()); + + // updating the shared MRM data is purely async, so we just wait for it + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { + boolean exists; + do { + exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))); + } while (exists); + }); + } + private List get(final UUID destinationUuid, final byte destinationDeviceId, final int messageCount) { return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId)) @@ -573,7 +645,7 @@ void setup() throws Exception { messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messagesCache = new MessagesCache(mockCluster, mock(ExecutorService.class), messageDeliveryScheduler, - Executors.newSingleThreadExecutor(), Clock.systemUTC()); + Executors.newSingleThreadExecutor(), Clock.systemUTC(), mock(DynamicConfigurationManager.class)); } @AfterEach @@ -755,18 +827,85 @@ private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, fin private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender, final long timestamp) { final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(timestamp) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setServerGuid(messageGuid.toString()) - .setDestinationUuid(UUID.randomUUID().toString()); + .setDestinationServiceId(UUID.randomUUID().toString()); if (!sealedSender) { envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1) - .setSourceUuid(UUID.randomUUID().toString()); + .setSourceServiceId(UUID.randomUUID().toString()); } return envelopeBuilder.build(); } + + static SealedSenderMultiRecipientMessage generateRandomMrmMessage( + Map> destinations) { + + + try { + final ByteBuffer prefix = ByteBuffer.allocate(7); + prefix.put((byte) 0x23); // version + writeVarint(prefix, destinations.size()); // recipient count + prefix.flip(); + + List recipients = new ArrayList<>(destinations.size()); + + for (Map.Entry> aciServiceIdentifierAndDeviceIds : destinations.entrySet()) { + + final AciServiceIdentifier destination = aciServiceIdentifierAndDeviceIds.getKey(); + final List deviceIds = aciServiceIdentifierAndDeviceIds.getValue(); + + assert deviceIds.size() < 255; + + final ByteBuffer recipient = ByteBuffer.allocate(17 + 3 * deviceIds.size() + 48); + + recipient.put(destination.toFixedWidthByteArray()); + for (int i = 0; i < deviceIds.size(); i++) { + final int hasMore = i == deviceIds.size() - 1 ? 0x0000 : 0x8000; + recipient.put(new byte[]{deviceIds.get(i)}); // device ID + recipient.putShort((short) ((100 + deviceIds.get(i)) | hasMore)); // registration ID + } + + final byte[] keyMaterial = new byte[48]; + ThreadLocalRandom.current().nextBytes(keyMaterial); + recipient.put(keyMaterial); + + recipients.add(recipient); + } + + final byte[] commonPayload = new byte[64]; + ThreadLocalRandom.current().nextBytes(commonPayload); + + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + baos.write(prefix.array(), 0, prefix.limit()); + for (ByteBuffer recipient : recipients) { + baos.write(recipient.array()); + } + baos.write(commonPayload); + + return SealedSenderMultiRecipientMessage.parse(baos.toByteArray()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + static SealedSenderMultiRecipientMessage generateRandomMrmMessage(AciServiceIdentifier destination, + byte... deviceIds) { + + final Map> destinations = new HashMap<>(); + destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds))); + return generateRandomMrmMessage(destinations); + } + + private static void writeVarint(ByteBuffer bb, long n) { + while (n >= 0x80) { + bb.put((byte) (n & 0x7F | 0x80)); + n = n >> 7; + } + bb.put((byte) (n & 0x7F)); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java index 7009ef238..bb44bd815 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java @@ -6,8 +6,6 @@ package org.whispersystems.textsecuregcm.storage; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; import java.time.Duration; @@ -31,7 +29,6 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.MessageHelper; -import org.whispersystems.textsecuregcm.util.SystemMapper; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; @@ -47,31 +44,31 @@ class MessagesDynamoDbTest { final long serverTimestamp = System.currentTimeMillis(); MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER); - builder.setTimestamp(123456789L); + builder.setClientTimestamp(123456789L); builder.setContent(ByteString.copyFrom(new byte[]{(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF})); builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerTimestamp(serverTimestamp); - builder.setDestinationUuid(UUID.randomUUID().toString()); + builder.setDestinationServiceId(UUID.randomUUID().toString()); MESSAGE1 = builder.build(); builder.setType(MessageProtos.Envelope.Type.CIPHERTEXT); - builder.setSourceUuid(UUID.randomUUID().toString()); + builder.setSourceServiceId(UUID.randomUUID().toString()); builder.setSourceDevice(1); builder.setContent(ByteString.copyFromUtf8("MOO")); builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerTimestamp(serverTimestamp + 1); - builder.setDestinationUuid(UUID.randomUUID().toString()); + builder.setDestinationServiceId(UUID.randomUUID().toString()); MESSAGE2 = builder.build(); builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER); - builder.clearSourceUuid(); + builder.clearSourceDevice(); builder.clearSourceDevice(); builder.setContent(ByteString.copyFromUtf8("COW")); builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerTimestamp(serverTimestamp); // Test same millisecond arrival for two different messages - builder.setDestinationUuid(UUID.randomUUID().toString()); + builder.setDestinationServiceId(UUID.randomUUID().toString()); MESSAGE3 = builder.build(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index e655fcb3b..ef4e66157 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -35,7 +35,7 @@ class MessagesManagerTest { void insert() { final UUID sourceAci = UUID.randomUUID(); final Envelope message = Envelope.newBuilder() - .setSourceUuid(sourceAci.toString()) + .setSourceServiceId(sourceAci.toString()) .build(); final UUID destinationUuid = UUID.randomUUID(); @@ -45,7 +45,7 @@ void insert() { verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class)); final Envelope syncMessage = Envelope.newBuilder(message) - .setSourceUuid(destinationUuid.toString()) + .setSourceServiceId(destinationUuid.toString()) .build(); messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java index 30039d5ca..e7cdf6fd4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java @@ -17,11 +17,11 @@ public static MessageProtos.Envelope createMessage(UUID senderUuid, final byte s return MessageProtos.Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(0) - .setSourceUuid(senderUuid.toString()) + .setSourceServiceId(senderUuid.toString()) .setSourceDevice(senderDeviceId) - .setDestinationUuid(destinationUuid.toString()) + .setDestinationServiceId(destinationUuid.toString()) .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) .build(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index e9d0863d6..47621eeb0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -44,6 +44,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; @@ -55,6 +56,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.MessagesCache; @@ -85,6 +87,8 @@ class WebSocketConnectionIntegrationTest { private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; + private DynamicConfigurationManager dynamicConfigurationManager; + private long serialTimestamp = System.currentTimeMillis(); @BeforeEach @@ -92,8 +96,10 @@ void setUp() throws Exception { sharedExecutorService = Executors.newSingleThreadExecutor(); scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); + dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, - messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); + messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7), sharedExecutorService); @@ -381,12 +387,12 @@ private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { final long timestamp = serialTimestamp++; return MessageProtos.Envelope.newBuilder() - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(timestamp) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setServerGuid(messageGuid.toString()) - .setDestinationUuid(UUID.randomUUID().toString()) + .setDestinationServiceId(UUID.randomUUID().toString()) .build(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 9c4346686..08c83f60e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -48,7 +48,6 @@ import java.util.stream.IntStream; import java.util.stream.Stream; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -297,19 +296,19 @@ void testPendingSend() { final Envelope firstMessage = Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) - .setSourceUuid(UUID.randomUUID().toString()) - .setDestinationUuid(accountUuid.toString()) + .setSourceServiceId(UUID.randomUUID().toString()) + .setDestinationServiceId(accountUuid.toString()) .setUpdatedPni(UUID.randomUUID().toString()) - .setTimestamp(System.currentTimeMillis()) + .setClientTimestamp(System.currentTimeMillis()) .setSourceDevice(1) .setType(Envelope.Type.CIPHERTEXT) .build(); final Envelope secondMessage = Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) - .setSourceUuid(senderTwoUuid.toString()) - .setDestinationUuid(accountUuid.toString()) - .setTimestamp(System.currentTimeMillis()) + .setSourceServiceId(senderTwoUuid.toString()) + .setDestinationServiceId(accountUuid.toString()) + .setClientTimestamp(System.currentTimeMillis()) .setSourceDevice(2) .setType(Envelope.Type.CIPHERTEXT) .build(); @@ -365,7 +364,7 @@ void testPendingSend() { futures.get(0).completeExceptionally(new IOException()); verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), - eq(secondMessage.getTimestamp())); + eq(secondMessage.getClientTimestamp())); connection.stop(); verify(client).close(anyInt(), anyString()); @@ -616,10 +615,10 @@ void testProcessStoredMessagesContainsSenderUuid() { final byte[] body = argument.get(); try { final Envelope envelope = Envelope.parseFrom(body); - if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) { + if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) { return false; } - return envelope.getSourceUuid().equals(senderUuid.toString()); + return envelope.getSourceServiceId().equals(senderUuid.toString()); } catch (InvalidProtocolBufferException e) { return false; } @@ -627,7 +626,7 @@ void testProcessStoredMessagesContainsSenderUuid() { verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { + private WebSocketConnection webSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); @@ -933,11 +932,11 @@ private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long times return Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) .setType(Envelope.Type.CIPHERTEXT) - .setTimestamp(timestamp) + .setClientTimestamp(timestamp) .setServerTimestamp(0) - .setSourceUuid(senderUuid.toString()) + .setSourceServiceId(senderUuid.toString()) .setSourceDevice(SOURCE_DEVICE_ID) - .setDestinationUuid(destinationUuid.toString()) + .setDestinationServiceId(destinationUuid.toString()) .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) .build(); }