diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java index 4e55f47cf..ac9e3c86d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java @@ -5,9 +5,10 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; -public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) { +public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean fetchSharedMrmData, + boolean useSharedMrmData) { public DynamicMessagesConfiguration() { - this(false, false); + this(false, false, false); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 4593b8085..f9f68b8f5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -957,7 +957,7 @@ public class MessageController { if (sharedMrmKey != null) { messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); } - // mrm views phase 1: always set content + // mrm views phase 2: always set content messageBuilder.setContent(ByteString.copyFrom(payload)); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 413ea877e..546484c9f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -137,6 +137,9 @@ public class MessagesCache { private final Counter staleEphemeralMessagesCounter = Metrics.counter( name(MessagesCache.class, "staleEphemeralMessages")); private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved")); + private final Counter mrmRetrievalErrorCounter = Metrics.counter(name(MessagesCache.class, "mrmRetrievalError")); + private final Counter mrmPhaseTwoMissingContentCounter = Metrics.counter( + name(MessagesCache.class, "mrmPhaseTwoMissingContent")); private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter( name(MessagesCache.class, "sharedMrmKeyRemoved")); @@ -349,19 +352,15 @@ public class MessagesCache { final Mono messageMono; if (message.hasSharedMrmKey()) { - final Mono experimentMono; if (isStaleEphemeralMessage(message, earliestAllowableEphemeralTimestamp)) { // skip fetching content for message that will be discarded - experimentMono = Mono.empty(); + messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build()); } else { - experimentMono = maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice); + // mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that + // enables fetching and using it (the stored messages still always have `content` set upstream) + messageMono = getMessageWithSharedMrmData(message, destinationDevice); } - // mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content - // To avoid races, wait for the experiment to run, but ignore any errors - messageMono = experimentMono - .onErrorComplete() - .then(Mono.just(message.toBuilder().clearSharedMrmKey().build())); } else { messageMono = Mono.just(message); } @@ -378,14 +377,23 @@ public class MessagesCache { } /** - * Runs the fetch and compare logic for the MRM view experiment, if it is enabled. + * Returns the given message with its shared MRM data. * - * @see DynamicMessagesConfiguration#mrmViewExperimentEnabled() + * @see DynamicMessagesConfiguration#fetchSharedMrmData() + * @see DynamicMessagesConfiguration#useSharedMrmData() */ - private Mono maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid, + private Mono getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage, final byte destinationDevice) { - if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration() - .mrmViewExperimentEnabled()) { + + assert mrmMessage.hasSharedMrmKey(); + + // mrm views phase 2: messages have content + if (!mrmMessage.hasContent()) { + mrmPhaseTwoMissingContentCounter.increment(); + } + + if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().fetchSharedMrmData() + || !mrmMessage.hasContent()) { final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME); @@ -394,7 +402,7 @@ public class MessagesCache { // the message might be addressed to the account's PNI, so use the service ID from the envelope ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice); - final Mono mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive( + final Mono messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive( conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey) .collectList() .publishOn(messageDeliveryScheduler))) @@ -416,14 +424,25 @@ public class MessagesCache { sink.error(e); } }) + .onErrorResume(throwable -> { + logger.warn("Failed to retrieve shared mrm data", throwable); + mrmRetrievalErrorCounter.increment(); + return Mono.empty(); + }) .share(); - experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), mrmMessageMono); + if (mrmMessage.hasContent()) { + experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono); + } - return mrmMessageMono; - } else { - return Mono.empty(); + if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData() + || !mrmMessage.hasContent()) { + return messageFromRedisMono; + } } + + // if fetching or using shared data is disabled, fallback to just() with the existing message + return Mono.just(mrmMessage.toBuilder().clearSharedMrmKey().build()); } /** @@ -497,13 +516,9 @@ public class MessagesCache { .concatMap(message -> { final Mono messageMono; if (message.hasSharedMrmKey()) { - final Mono experimentMono = maybeRunMrmViewExperiment(message, accountUuid, destinationDevice); - - // mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content - // To avoid races, wait for the experiment to run, but ignore any errors - messageMono = experimentMono - .onErrorComplete() - .then(Mono.just(message.toBuilder().clearSharedMrmKey().build())); + // mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that + // enables fetching and using it (the stored messages still always have `content` set upstream) + messageMono = getMessageWithSharedMrmData(message, destinationDevice); } else { messageMono = Mono.just(message); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index a2c54c905..7b648ac98 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -252,7 +252,8 @@ class MessageControllerTest { final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration); - when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true)); + when(dynamicConfiguration.getMessagesConfiguration()).thenReturn( + new DynamicMessagesConfiguration(true, true, true)); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 93b3b19de..299898323 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -62,6 +62,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.reactivestreams.Publisher; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.ServiceId; @@ -96,6 +97,7 @@ class MessagesCacheTest { private MessagesCache messagesCache; private DynamicConfigurationManager dynamicConfigurationManager; + private DynamicConfiguration dynamicConfiguration; private static final UUID DESTINATION_UUID = UUID.randomUUID(); @@ -103,8 +105,9 @@ class MessagesCacheTest { @BeforeEach void setUp() throws Exception { - final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); - when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true)); + dynamicConfiguration = mock(DynamicConfiguration.class); + when(dynamicConfiguration.getMessagesConfiguration()).thenReturn( + new DynamicMessagesConfiguration(true, true, true)); dynamicConfigurationManager = mock(DynamicConfigurationManager.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); @@ -399,9 +402,13 @@ class MessagesCacheTest { assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.getFirst())); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Exception { + @CartesianTest + void testMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent, + @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) { + + when(dynamicConfiguration.getMessagesConfiguration()) + .thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData)); + final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = 1; @@ -419,7 +426,7 @@ class MessagesCacheTest { .toBuilder() // clear some things added by the helper .clearServerGuid() - // mrm views phase 1: messages have content + // mrm views phase 2: messages have content .setContent( ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(destinationServiceId.toLibsignal())))) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) @@ -430,10 +437,70 @@ class MessagesCacheTest { .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); final List messages = get(destinationServiceId.uuid(), deviceId, 1); + if (useSharedMrmData && !sharedMrmKeyPresent) { + assertTrue(messages.isEmpty()); + } else { + + assertEquals(1, messages.size()); + assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid())); + assertFalse(messages.getFirst().hasSharedMrmKey()); + final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients() + .get(destinationServiceId.toLibsignal()); + assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray()); + } + + final Optional removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid) + .join(); + + assertTrue(removedMessage.isPresent()); + assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString())); + assertTrue(get(destinationServiceId.uuid(), deviceId, 1).isEmpty()); + + // updating the shared MRM data is purely async, so we just wait for it + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { + boolean exists; + do { + exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)); + } while (exists); + }, "Shared MRM data should be deleted asynchronously"); + } + + @CartesianTest + void testMultiRecipientMessagePhase2MissingContentSafeguard( + @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData, + @CartesianTest.Values(booleans = {true, false}) final boolean fetchSharedMrmData) { + + when(dynamicConfiguration.getMessagesConfiguration()) + .thenReturn(new DynamicMessagesConfiguration(true, fetchSharedMrmData, useSharedMrmData)); + + final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID()); + final byte deviceId = 1; + + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId); + + final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); + + final UUID guid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(guid, destinationServiceId, true) + .toBuilder() + // clear some things added by the helper + .clearServerGuid() + // mrm views phase 2: there is a safeguard against missing content, even if the dynamic configuration + // is to not fetch or use shared MRM data + .clearContent() + .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) + .build(); + messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message); + + assertEquals(1, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); + + final List messages = get(destinationServiceId.uuid(), deviceId, 1); + assertEquals(1, messages.size()); assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid())); assertFalse(messages.getFirst().hasSharedMrmKey()); - final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients() .get(destinationServiceId.toLibsignal()); assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray()); @@ -455,19 +522,24 @@ class MessagesCacheTest { }, "Shared MRM data should be deleted asynchronously"); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) { + @CartesianTest + void testGetMessagesToPersist(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent, + @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) { + + when(dynamicConfiguration.getMessagesConfiguration()) + .thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData)); + final UUID destinationUuid = UUID.randomUUID(); + final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid); final byte deviceId = 1; final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope message = generateRandomMessage(destinationUuid, true); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, + new AciServiceIdentifier(destinationUuid), true); messagesCache.insert(messageGuid, destinationUuid, deviceId, message); - final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( - new AciServiceIdentifier(destinationUuid), deviceId); + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId); final byte[] sharedMrmDataKey; if (sharedMrmKeyPresent) { @@ -477,31 +549,35 @@ class MessagesCacheTest { } final UUID mrmMessageGuid = UUID.randomUUID(); - final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, true) + final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, destinationServiceId, true) .toBuilder() // clear some things added by the helper .clearServerGuid() - // mrm views phase 1: messages have content + // mrm views phase 2: messages have content .setContent( ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid))))) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .build(); messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage); - final List messages = get(destinationUuid, deviceId, 100); + final List messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100); - assertEquals(2, messages.size()); + if (useSharedMrmData && !sharedMrmKeyPresent) { + assertEquals(1, messages.size()); + } else { + assertEquals(2, messages.size()); + + assertEquals(mrmMessage.toBuilder(). + clearSharedMrmKey(). + setServerGuid(mrmMessageGuid.toString()) + .build(), + messages.getLast()); + } assertEquals(message.toBuilder() .setServerGuid(messageGuid.toString()) .build(), messages.getFirst()); - - assertEquals(mrmMessage.toBuilder(). - clearSharedMrmKey(). - setServerGuid(mrmMessageGuid.toString()) - .build(), - messages.getLast()); } private List get(final UUID destinationUuid, final byte destinationDeviceId,