Use MRM shared data views

This commit is contained in:
Chris Eager 2024-11-12 12:23:54 -06:00 committed by Chris Eager
parent 085f013bf9
commit ea75c39b58
10 changed files with 24 additions and 166 deletions

View File

@ -603,7 +603,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler,
messageDeletionAsyncExecutor, clock, dynamicConfigurationManager); messageDeletionAsyncExecutor, clock);
ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases, ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases,
recurringJobExecutor, recurringJobExecutor,
config.getClientReleaseConfiguration().refreshInterval(), config.getClientReleaseConfiguration().refreshInterval(),

View File

@ -68,10 +68,6 @@ public class DynamicConfiguration {
@Valid @Valid
DynamicMetricsConfiguration metricsConfiguration = new DynamicMetricsConfiguration(false); DynamicMetricsConfiguration metricsConfiguration = new DynamicMetricsConfiguration(false);
@JsonProperty
@Valid
DynamicMessagesConfiguration messagesConfiguration = new DynamicMessagesConfiguration();
@JsonProperty @JsonProperty
@Valid @Valid
List<String> svrStatusCodesToIgnoreForAccountDeletion = Collections.emptyList(); List<String> svrStatusCodesToIgnoreForAccountDeletion = Collections.emptyList();
@ -130,10 +126,6 @@ public class DynamicConfiguration {
return metricsConfiguration; return metricsConfiguration;
} }
public DynamicMessagesConfiguration getMessagesConfiguration() {
return messagesConfiguration;
}
public List<String> getSvrStatusCodesToIgnoreForAccountDeletion() { public List<String> getSvrStatusCodesToIgnoreForAccountDeletion() {
return svrStatusCodesToIgnoreForAccountDeletion; return svrStatusCodesToIgnoreForAccountDeletion;
} }

View File

@ -1,13 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
public record DynamicMessagesConfiguration(boolean useSharedMrmData) {
public DynamicMessagesConfiguration() {
this(false);
}
}

View File

@ -952,9 +952,6 @@ public class MessageController {
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
// mrm views phase 3: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} }

View File

@ -36,10 +36,7 @@ import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@ -116,8 +113,6 @@ public class MessagesCache {
// messageDeletionExecutorService wrapped into a reactor Scheduler // messageDeletionExecutorService wrapped into a reactor Scheduler
private final Scheduler messageDeletionScheduler; private final Scheduler messageDeletionScheduler;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final MessagesCacheInsertScript insertScript; private final MessagesCacheInsertScript insertScript;
private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript; private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript;
private final MessagesCacheRemoveByGuidScript removeByGuidScript; private final MessagesCacheRemoveByGuidScript removeByGuidScript;
@ -137,10 +132,8 @@ public class MessagesCache {
private final Counter staleEphemeralMessagesCounter = Metrics.counter( private final Counter staleEphemeralMessagesCounter = Metrics.counter(
name(MessagesCache.class, "staleEphemeralMessages")); name(MessagesCache.class, "staleEphemeralMessages"));
private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved")); private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved"));
private final String MRM_RETRIEVAL_ERROR_COUNTER_NAME = "mrmRetrievalError"; private final String MRM_RETRIEVAL_ERROR_COUNTER_NAME = name(MessagesCache.class, "mrmRetrievalError");
private final String EPHEMERAL_TAG_NAME = "ephemeral"; private final String EPHEMERAL_TAG_NAME = "ephemeral";
private final Counter mrmPhaseTwoMissingContentCounter = Metrics.counter(
name(MessagesCache.class, "mrmPhaseTwoMissingContent"));
private final Counter skippedStaleEphemeralMrmCounter = Metrics.counter( private final Counter skippedStaleEphemeralMrmCounter = Metrics.counter(
name(MessagesCache.class, "skippedStaleEphemeralMrm")); name(MessagesCache.class, "skippedStaleEphemeralMrm"));
private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter( private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter(
@ -149,8 +142,6 @@ public class MessagesCache {
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
private static final String MRM_VIEWS_EXPERIMENT_NAME = "mrmViews";
@VisibleForTesting @VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
@ -164,8 +155,7 @@ public class MessagesCache {
public MessagesCache(final FaultTolerantRedisClusterClient redisCluster, public MessagesCache(final FaultTolerantRedisClusterClient redisCluster,
final Scheduler messageDeliveryScheduler, final Scheduler messageDeliveryScheduler,
final ExecutorService messageDeletionExecutorService, final ExecutorService messageDeletionExecutorService,
final Clock clock, final Clock clock)
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager)
throws IOException { throws IOException {
this( this(
@ -173,7 +163,6 @@ public class MessagesCache {
messageDeliveryScheduler, messageDeliveryScheduler,
messageDeletionExecutorService, messageDeletionExecutorService,
clock, clock,
dynamicConfigurationManager,
new MessagesCacheInsertScript(redisCluster), new MessagesCacheInsertScript(redisCluster),
new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster), new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster),
new MessagesCacheGetItemsScript(redisCluster), new MessagesCacheGetItemsScript(redisCluster),
@ -189,7 +178,6 @@ public class MessagesCache {
MessagesCache(final FaultTolerantRedisClusterClient redisCluster, MessagesCache(final FaultTolerantRedisClusterClient redisCluster,
final Scheduler messageDeliveryScheduler, final Scheduler messageDeliveryScheduler,
final ExecutorService messageDeletionExecutorService, final Clock clock, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final MessagesCacheInsertScript insertScript, final MessagesCacheInsertScript insertScript,
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript, final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript,
final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript, final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript,
@ -205,8 +193,6 @@ public class MessagesCache {
this.messageDeletionExecutorService = messageDeletionExecutorService; this.messageDeletionExecutorService = messageDeletionExecutorService;
this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion"); this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion");
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.insertScript = insertScript; this.insertScript = insertScript;
this.insertMrmScript = insertMrmScript; this.insertMrmScript = insertMrmScript;
this.removeByGuidScript = removeByGuidScript; this.removeByGuidScript = removeByGuidScript;
@ -371,8 +357,6 @@ public class MessagesCache {
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build()); messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
skippedStaleEphemeralMrmCounter.increment(); skippedStaleEphemeralMrmCounter.increment();
} else { } else {
// mrm views phase 3: fetch shared MRM data -- internally depends on dynamic config that
// enables using it (the stored messages still always have `content` set upstream)
messageMono = getMessageWithSharedMrmData(message, destinationDevice); messageMono = getMessageWithSharedMrmData(message, destinationDevice);
} }
@ -393,27 +377,18 @@ public class MessagesCache {
/** /**
* Returns the given message with its shared MRM data. * Returns the given message with its shared MRM data.
*
* @see DynamicMessagesConfiguration#useSharedMrmData()
*/ */
private Mono<MessageProtos.Envelope> getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage, private Mono<MessageProtos.Envelope> getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage,
final byte destinationDevice) { final byte destinationDevice) {
assert mrmMessage.hasSharedMrmKey(); assert mrmMessage.hasSharedMrmKey();
// mrm views phase 2: messages have content
if (!mrmMessage.hasContent()) {
mrmPhaseTwoMissingContentCounter.increment();
}
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray(); final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey( final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
// the message might be addressed to the account's PNI, so use the service ID from the envelope // the message might be addressed to the account's PNI, so use the service ID from the envelope
ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice); ServiceIdentifier.valueOf(mrmMessage.getDestinationServiceId()), destinationDevice);
final Mono<MessageProtos.Envelope> messageFromRedisMono = Mono.from(redisCluster.withBinaryClusterReactive( return Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey) conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList() .collectList()
.publishOn(messageDeliveryScheduler))) .publishOn(messageDeliveryScheduler)))
@ -444,18 +419,6 @@ public class MessagesCache {
return Mono.empty(); return Mono.empty();
}) })
.share(); .share();
if (mrmMessage.hasContent()) {
experiment.compareMonoResult(mrmMessage.toBuilder().clearSharedMrmKey().build(), messageFromRedisMono);
}
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().useSharedMrmData()
|| !mrmMessage.hasContent()) {
return messageFromRedisMono;
}
// if fetching or using shared data is disabled, fallback to just() with the existing message
return Mono.just(mrmMessage.toBuilder().clearSharedMrmKey().build());
} }
/** /**
@ -529,8 +492,6 @@ public class MessagesCache {
.concatMap(message -> { .concatMap(message -> {
final Mono<MessageProtos.Envelope> messageMono; final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) { if (message.hasSharedMrmKey()) {
// mrm views phase 2: fetch shared MRM data -- internally depends on dynamic config that
// enables fetching and using it (the stored messages still always have `content` set upstream)
messageMono = getMessageWithSharedMrmData(message, destinationDevice); messageMono = getMessageWithSharedMrmData(message, destinationDevice);
} else { } else {
messageMono = Mono.just(message); messageMono = Mono.just(message);

View File

@ -212,7 +212,7 @@ record CommandDependencies(
storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration()); storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration());
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(), configuration.getDynamoDbTables().getReportMessage().getTableName(),

View File

@ -81,7 +81,7 @@ class MessagePersisterIntegrationTest {
final AccountsManager accountsManager = mock(AccountsManager.class); final AccountsManager accountsManager = mock(AccountsManager.class);
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService); messageDeletionExecutorService);

View File

@ -99,7 +99,7 @@ class MessagePersisterTest {
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1); dynamicConfigurationManager, PERSIST_DELAY, 1);

View File

@ -64,12 +64,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -98,26 +94,17 @@ class MessagesCacheTest {
private Scheduler messageDeliveryScheduler; private Scheduler messageDeliveryScheduler;
private MessagesCache messagesCache; private MessagesCache messagesCache;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private DynamicConfiguration dynamicConfiguration;
private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final byte DESTINATION_DEVICE_ID = 7; private static final byte DESTINATION_DEVICE_ID = 7;
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(
new DynamicMessagesConfiguration(true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
sharedExecutorService = Executors.newSingleThreadExecutor(); sharedExecutorService = Executors.newSingleThreadExecutor();
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
} }
@AfterEach @AfterEach
@ -321,7 +308,7 @@ class MessagesCacheTest {
} }
final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, cacheClock, dynamicConfigurationManager); messageDeliveryScheduler, sharedExecutorService, cacheClock);
final List<MessageProtos.Envelope> actualMessages = Flux.from( final List<MessageProtos.Envelope> actualMessages = Flux.from(
messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID)) messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID))
@ -429,12 +416,9 @@ class MessagesCacheTest {
assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.getFirst())); assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.getFirst()));
} }
@CartesianTest @ParameterizedTest
void testMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent, @ValueSource(booleans = {true, false})
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) { void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1; final byte deviceId = 1;
@ -453,10 +437,8 @@ class MessagesCacheTest {
.toBuilder() .toBuilder()
// clear some things added by the helper // clear some things added by the helper
.clearServerGuid() .clearServerGuid()
// mrm views phase 2: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(destinationServiceId.toLibsignal()))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.clearContent()
.build(); .build();
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message); messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
@ -464,7 +446,7 @@ class MessagesCacheTest {
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1); final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1);
if (useSharedMrmData && !sharedMrmKeyPresent) { if (!sharedMrmKeyPresent) {
assertTrue(messages.isEmpty()); assertTrue(messages.isEmpty());
} else { } else {
@ -495,65 +477,7 @@ class MessagesCacheTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testMultiRecipientMessagePhase2MissingContentSafeguard(final boolean useSharedMrmData) { void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, destinationServiceId, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 2: there is a safeguard against missing content, even if the dynamic configuration
// is to not fetch or use shared MRM data
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
assertEquals(1, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1);
assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(destinationServiceId.toLibsignal());
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid)
.join();
assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationServiceId.uuid(), deviceId, 1).isEmpty());
// updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey));
} while (exists);
}, "Shared MRM data should be deleted asynchronously");
}
@CartesianTest
void testGetMessagesToPersist(@CartesianTest.Values(booleans = {true, false}) final boolean sharedMrmKeyPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean useSharedMrmData) {
when(dynamicConfiguration.getMessagesConfiguration())
.thenReturn(new DynamicMessagesConfiguration(useSharedMrmData));
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(destinationUuid);
@ -578,24 +502,22 @@ class MessagesCacheTest {
final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, destinationServiceId, true) final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, destinationServiceId, true)
.toBuilder() .toBuilder()
// clear some things added by the helper // clear some things added by the helper
.clearServerGuid() .clearContent()
// mrm views phase 2: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid)))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build(); .build();
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage); messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage);
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100); final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);
if (useSharedMrmData && !sharedMrmKeyPresent) { if (!sharedMrmKeyPresent) {
assertEquals(1, messages.size()); assertEquals(1, messages.size());
} else { } else {
assertEquals(2, messages.size()); assertEquals(2, messages.size());
assertEquals(mrmMessage.toBuilder(). assertEquals(mrmMessage.toBuilder()
clearSharedMrmKey(). .clearSharedMrmKey()
setServerGuid(mrmMessageGuid.toString()) .setContent(ByteString.copyFrom(
mrm.messageForRecipient(mrm.getRecipients().get(destinationServiceId.toLibsignal()))))
.build(), .build(),
messages.getLast()); messages.getLast());
} }
@ -636,7 +558,7 @@ class MessagesCacheTest {
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(mockCluster, messageDeliveryScheduler, messagesCache = new MessagesCache(mockCluster, messageDeliveryScheduler,
Executors.newSingleThreadExecutor(), Clock.systemUTC(), mock(DynamicConfigurationManager.class)); Executors.newSingleThreadExecutor(), Clock.systemUTC());
} }
@AfterEach @AfterEach
@ -822,8 +744,7 @@ class MessagesCacheTest {
} }
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid,
final ServiceIdentifier destinationServiceId, final boolean sealedSender, final ServiceIdentifier destinationServiceId, final boolean sealedSender, final long timestamp) {
final long timestamp) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setClientTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(timestamp) .setServerTimestamp(timestamp)

View File

@ -99,7 +99,7 @@ class WebSocketConnectionIntegrationTest {
dynamicConfigurationManager = mock(DynamicConfigurationManager.class); dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7),
sharedExecutorService); sharedExecutorService);