From ee5df0e11cb5f4ada4fa030550f40555c1bc6a97 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Sat, 9 Nov 2024 11:23:14 -0600 Subject: [PATCH] Always store and fetch shared MRM data --- .../dynamic/DynamicMessagesConfiguration.java | 5 +- .../controllers/MessageController.java | 15 +-- .../textsecuregcm/storage/MessagesCache.java | 93 +++++++++---------- .../controllers/MessageControllerTest.java | 32 ++++--- .../storage/MessagesCacheTest.java | 15 ++- 5 files changed, 78 insertions(+), 82 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java index ac9e3c86d..ece407844 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java @@ -5,10 +5,9 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; -public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean fetchSharedMrmData, - boolean useSharedMrmData) { +public record DynamicMessagesConfiguration(boolean useSharedMrmData) { public DynamicMessagesConfiguration() { - this(false, false, false); + this(false); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index f236800bf..859c7354e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -655,10 +655,7 @@ public class MessageController { } try { - @Nullable final byte[] sharedMrmKey = - dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData() - ? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage) - : null; + final byte[] sharedMrmKey = messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage); CompletableFuture.allOf( recipients.values().stream() @@ -941,7 +938,7 @@ public class MessageController { boolean story, boolean urgent, byte[] payload, - @Nullable byte[] sharedMrmKey) { + byte[] sharedMrmKey) { final Envelope.Builder messageBuilder = Envelope.newBuilder(); final long serverTimestamp = System.currentTimeMillis(); @@ -952,12 +949,10 @@ public class MessageController { .setServerTimestamp(serverTimestamp) .setStory(story) .setUrgent(urgent) - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()); + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); - if (sharedMrmKey != null) { - messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); - } - // mrm views phase 2: always set content + // mrm views phase 3: always set content messageBuilder.setContent(ByteString.copyFrom(payload)); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index f830cf3f8..83621187d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -370,8 +370,8 @@ public class MessagesCache { messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build()); skippedStaleEphemeralMrmCounter.increment(); } else { - // mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that - // enables fetching and using it (the stored messages still always have `content` set upstream) + // mrm views phase 3: fetch shared MRM data -- internally depends on dynamic config that + // enables using it (the stored messages still always have `content` set upstream) messageMono = getMessageWithSharedMrmData(message, destinationDevice); } @@ -393,7 +393,6 @@ public class MessagesCache { /** * Returns the given message with its shared MRM data. * - * @see DynamicMessagesConfiguration#fetchSharedMrmData() * @see DynamicMessagesConfiguration#useSharedMrmData() */ private Mono getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage, @@ -406,53 +405,49 @@ public class MessagesCache { mrmPhaseTwoMissingContentCounter.increment(); } - if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().fetchSharedMrmData() + final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME); + + final byte[] key = mrmMessage.getSharedMrmKey().toByteArray(); + final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey( + // the message might be addressed to the account's PNI, so use the service ID from the envelope + ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice); + + final Mono messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive( + conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey) + .collectList() + .publishOn(messageDeliveryScheduler))) + .handle((mrmDataAndView, sink) -> { + try { + assert mrmDataAndView.size() == 2; + + final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient( + mrmDataAndView.getFirst().getValue(), + mrmDataAndView.getLast().getValue()); + + sink.next(mrmMessage.toBuilder() + .clearSharedMrmKey() + .setContent(ByteString.copyFrom(content)) + .build()); + + mrmContentRetrievedCounter.increment(); + } catch (Exception e) { + sink.error(e); + } + }) + .onErrorResume(throwable -> { + logger.warn("Failed to retrieve shared mrm data", throwable); + mrmRetrievalErrorCounter.increment(); + return Mono.empty(); + }) + .share(); + + if (mrmMessage.hasContent()) { + experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono); + } + + if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData() || !mrmMessage.hasContent()) { - - final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME); - - final byte[] key = mrmMessage.getSharedMrmKey().toByteArray(); - final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey( - // the message might be addressed to the account's PNI, so use the service ID from the envelope - ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice); - - final Mono messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive( - conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey) - .collectList() - .publishOn(messageDeliveryScheduler))) - .handle((mrmDataAndView, sink) -> { - try { - assert mrmDataAndView.size() == 2; - - final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient( - mrmDataAndView.getFirst().getValue(), - mrmDataAndView.getLast().getValue()); - - sink.next(mrmMessage.toBuilder() - .clearSharedMrmKey() - .setContent(ByteString.copyFrom(content)) - .build()); - - mrmContentRetrievedCounter.increment(); - } catch (Exception e) { - sink.error(e); - } - }) - .onErrorResume(throwable -> { - logger.warn("Failed to retrieve shared mrm data", throwable); - mrmRetrievalErrorCounter.increment(); - return Mono.empty(); - }) - .share(); - - if (mrmMessage.hasContent()) { - experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono); - } - - if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData() - || !mrmMessage.hasContent()) { - return messageFromRedisMono; - } + return messageFromRedisMono; } // if fetching or using shared data is disabled, fallback to just() with the existing message diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index e75986856..c6bc89109 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -82,12 +82,12 @@ import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.ArgumentSets; import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; -import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.IncomingMessage; @@ -252,8 +252,6 @@ class MessageControllerTest { final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration); - when(dynamicConfiguration.getMessagesConfiguration()).thenReturn( - new DynamicMessagesConfiguration(true, true, true)); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); @@ -1141,6 +1139,10 @@ class MessageControllerTest { @Test void testManyRecipientMessage() throws Exception { + + when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) + .thenReturn(new byte[]{1}); + final int nRecipients = 999; final int devicesPerRecipient = 5; final List recipients = new ArrayList<>(); @@ -1152,8 +1154,8 @@ class MessageControllerTest { d -> generateTestDevice( (byte) d, 100 + d, 10 * d, true)) .collect(Collectors.toList()); - final UUID aci = new UUID(0L, (long) i); - final UUID pni = new UUID(1L, (long) i); + final UUID aci = new UUID(0L, i); + final UUID pni = new UUID(1L, i); final String e164 = String.format("+1408555%04d", i); final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES); @@ -1186,13 +1188,12 @@ class MessageControllerTest { // see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations private void testMultiRecipientMessage( - Map> destinations, - boolean authorize, - boolean isStory, - boolean urgent, - boolean explicitIdentifier, - int expectedStatus, - int expectedMessagesSent) throws Exception { + Map> destinations, boolean authorize, boolean isStory, boolean urgent, + boolean explicitIdentifier, int expectedStatus, int expectedMessagesSent) throws Exception { + + when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) + .thenReturn(new byte[]{1}); + final List recipients = new ArrayList<>(); destinations.forEach( (serviceIdentifier, deviceToRegistrationId) -> @@ -1383,6 +1384,10 @@ class MessageControllerTest { @Test void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception { + + when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) + .thenReturn(new byte[]{1}); + final List recipients = List.of( new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]), new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), @@ -1550,6 +1555,9 @@ class MessageControllerTest { @MethodSource void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) { + when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) + .thenReturn(new byte[]{1}); + final Recipient r1; if (known) { r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index ef785bf35..fbee9514c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -109,7 +109,7 @@ class MessagesCacheTest { void setUp() throws Exception { dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfiguration.getMessagesConfiguration()).thenReturn( - new DynamicMessagesConfiguration(true, true, true)); + new DynamicMessagesConfiguration(true)); dynamicConfigurationManager = mock(DynamicConfigurationManager.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); @@ -434,7 +434,7 @@ class MessagesCacheTest { @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) { when(dynamicConfiguration.getMessagesConfiguration()) - .thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData)); + .thenReturn(new DynamicMessagesConfiguration(useSharedMrmData)); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = 1; @@ -493,13 +493,12 @@ class MessagesCacheTest { }, "Shared MRM data should be deleted asynchronously"); } - @CartesianTest - void testMultiRecipientMessagePhase2MissingContentSafeguard( - @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData, - @CartesianTest.Values(booleans = {true, false}) final boolean fetchSharedMrmData) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMultiRecipientMessagePhase2MissingContentSafeguard(final boolean useSharedMrmData) { when(dynamicConfiguration.getMessagesConfiguration()) - .thenReturn(new DynamicMessagesConfiguration(true, fetchSharedMrmData, useSharedMrmData)); + .thenReturn(new DynamicMessagesConfiguration(useSharedMrmData)); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID()); final byte deviceId = 1; @@ -554,7 +553,7 @@ class MessagesCacheTest { @CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) { when(dynamicConfiguration.getMessagesConfiguration()) - .thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData)); + .thenReturn(new DynamicMessagesConfiguration(useSharedMrmData)); final UUID destinationUuid = UUID.randomUUID(); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid);