Always store and fetch shared MRM data

This commit is contained in:
Chris Eager 2024-11-09 11:23:14 -06:00 committed by Chris Eager
parent d53a6e4c42
commit ee5df0e11c
5 changed files with 78 additions and 82 deletions

View File

@ -5,10 +5,9 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean fetchSharedMrmData,
boolean useSharedMrmData) {
public record DynamicMessagesConfiguration(boolean useSharedMrmData) {
public DynamicMessagesConfiguration() {
this(false, false, false);
this(false);
}
}

View File

@ -655,10 +655,7 @@ public class MessageController {
}
try {
@Nullable final byte[] sharedMrmKey =
dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData()
? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
: null;
final byte[] sharedMrmKey = messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage);
CompletableFuture.allOf(
recipients.values().stream()
@ -941,7 +938,7 @@ public class MessageController {
boolean story,
boolean urgent,
byte[] payload,
@Nullable byte[] sharedMrmKey) {
byte[] sharedMrmKey) {
final Envelope.Builder messageBuilder = Envelope.newBuilder();
final long serverTimestamp = System.currentTimeMillis();
@ -952,12 +949,10 @@ public class MessageController {
.setServerTimestamp(serverTimestamp)
.setStory(story)
.setUrgent(urgent)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString());
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
if (sharedMrmKey != null) {
messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
}
// mrm views phase 2: always set content
// mrm views phase 3: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);

View File

@ -370,8 +370,8 @@ public class MessagesCache {
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
skippedStaleEphemeralMrmCounter.increment();
} else {
// mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that
// enables fetching and using it (the stored messages still always have `content` set upstream)
// mrm views phase 3: fetch shared MRM data -- internally depends on dynamic config that
// enables using it (the stored messages still always have `content` set upstream)
messageMono = getMessageWithSharedMrmData(message, destinationDevice);
}
@ -393,7 +393,6 @@ public class MessagesCache {
/**
* Returns the given message with its shared MRM data.
*
* @see DynamicMessagesConfiguration#fetchSharedMrmData()
* @see DynamicMessagesConfiguration#useSharedMrmData()
*/
private Mono<MessageProtos.Envelope> getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage,
@ -406,53 +405,49 @@ public class MessagesCache {
mrmPhaseTwoMissingContentCounter.increment();
}
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().fetchSharedMrmData()
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
// the message might be addressed to the account's PNI, so use the service ID from the envelope
ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice);
final Mono<MessageProtos.Envelope> messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)))
.<MessageProtos.Envelope>handle((mrmDataAndView, sink) -> {
try {
assert mrmDataAndView.size() == 2;
final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient(
mrmDataAndView.getFirst().getValue(),
mrmDataAndView.getLast().getValue());
sink.next(mrmMessage.toBuilder()
.clearSharedMrmKey()
.setContent(ByteString.copyFrom(content))
.build());
mrmContentRetrievedCounter.increment();
} catch (Exception e) {
sink.error(e);
}
})
.onErrorResume(throwable -> {
logger.warn("Failed to retrieve shared mrm data", throwable);
mrmRetrievalErrorCounter.increment();
return Mono.empty();
})
.share();
if (mrmMessage.hasContent()) {
experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono);
}
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData()
|| !mrmMessage.hasContent()) {
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
// the message might be addressed to the account's PNI, so use the service ID from the envelope
ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice);
final Mono<MessageProtos.Envelope> messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)))
.<MessageProtos.Envelope>handle((mrmDataAndView, sink) -> {
try {
assert mrmDataAndView.size() == 2;
final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient(
mrmDataAndView.getFirst().getValue(),
mrmDataAndView.getLast().getValue());
sink.next(mrmMessage.toBuilder()
.clearSharedMrmKey()
.setContent(ByteString.copyFrom(content))
.build());
mrmContentRetrievedCounter.increment();
} catch (Exception e) {
sink.error(e);
}
})
.onErrorResume(throwable -> {
logger.warn("Failed to retrieve shared mrm data", throwable);
mrmRetrievalErrorCounter.increment();
return Mono.empty();
})
.share();
if (mrmMessage.hasContent()) {
experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono);
}
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData()
|| !mrmMessage.hasContent()) {
return messageFromRedisMono;
}
return messageFromRedisMono;
}
// if fetching or using shared data is disabled, fallback to just() with the existing message

View File

@ -82,12 +82,12 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@ -252,8 +252,6 @@ class MessageControllerTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(
new DynamicMessagesConfiguration(true, true, true));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
@ -1141,6 +1139,10 @@ class MessageControllerTest {
@Test
void testManyRecipientMessage() throws Exception {
when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class)))
.thenReturn(new byte[]{1});
final int nRecipients = 999;
final int devicesPerRecipient = 5;
final List<Recipient> recipients = new ArrayList<>();
@ -1152,8 +1154,8 @@ class MessageControllerTest {
d -> generateTestDevice(
(byte) d, 100 + d, 10 * d, true))
.collect(Collectors.toList());
final UUID aci = new UUID(0L, (long) i);
final UUID pni = new UUID(1L, (long) i);
final UUID aci = new UUID(0L, i);
final UUID pni = new UUID(1L, i);
final String e164 = String.format("+1408555%04d", i);
final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES);
@ -1186,13 +1188,12 @@ class MessageControllerTest {
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations
private void testMultiRecipientMessage(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
boolean authorize,
boolean isStory,
boolean urgent,
boolean explicitIdentifier,
int expectedStatus,
int expectedMessagesSent) throws Exception {
Map<ServiceIdentifier, Map<Byte, Integer>> destinations, boolean authorize, boolean isStory, boolean urgent,
boolean explicitIdentifier, int expectedStatus, int expectedMessagesSent) throws Exception {
when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class)))
.thenReturn(new byte[]{1});
final List<Recipient> recipients = new ArrayList<>();
destinations.forEach(
(serviceIdentifier, deviceToRegistrationId) ->
@ -1383,6 +1384,10 @@ class MessageControllerTest {
@Test
void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception {
when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class)))
.thenReturn(new byte[]{1});
final List<Recipient> recipients = List.of(
new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
@ -1550,6 +1555,9 @@ class MessageControllerTest {
@MethodSource
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) {
when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class)))
.thenReturn(new byte[]{1});
final Recipient r1;
if (known) {
r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);

View File

@ -109,7 +109,7 @@ class MessagesCacheTest {
void setUp() throws Exception {
dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(
new DynamicMessagesConfiguration(true, true, true));
new DynamicMessagesConfiguration(true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
@ -434,7 +434,7 @@ class MessagesCacheTest {
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData));
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
@ -493,13 +493,12 @@ class MessagesCacheTest {
}, "Shared MRM data should be deleted asynchronously");
}
@CartesianTest
void testMultiRecipientMessagePhase2MissingContentSafeguard(
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData,
@CartesianTest.Values(booleans = {true, false}) final boolean fetchSharedMrmData) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientMessagePhase2MissingContentSafeguard(final boolean useSharedMrmData) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, fetchSharedMrmData, useSharedMrmData));
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
@ -554,7 +553,7 @@ class MessagesCacheTest {
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(true, true, useSharedMrmData));
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid);