From 09b50383d77b8b8ec56ab1b74b022db8019cfec6 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Fri, 28 Feb 2025 11:11:42 -0600 Subject: [PATCH] Automatically trim primary queue when cache cannot be persisted --- .../DynamicMessagePersisterConfiguration.java | 20 ++++ .../storage/MessagePersister.java | 100 ++++++++++++++--- .../textsecuregcm/storage/MessagesCache.java | 47 ++++++-- .../storage/MessagesDynamoDb.java | 2 - .../storage/MessagePersisterTest.java | 104 ++++++++++++++++-- .../storage/MessagesCacheTest.java | 42 +++++++ 6 files changed, 280 insertions(+), 35 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagePersisterConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagePersisterConfiguration.java index d74cac20d..cfe079bbc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagePersisterConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagePersisterConfiguration.java @@ -6,13 +6,33 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; public class DynamicMessagePersisterConfiguration { @JsonProperty private boolean persistenceEnabled = true; + /** + * If we have to trim a client's persisted queue to make room to persist from Redis to DynamoDB, how much extra room should we make + */ + @JsonProperty + private double trimOversizedQueueExtraRoomRatio = 1.5; + + public DynamicMessagePersisterConfiguration() {} + + @VisibleForTesting + public DynamicMessagePersisterConfiguration(final boolean persistenceEnabled, final double trimOversizedQueueExtraRoomRatio) { + this.persistenceEnabled = persistenceEnabled; + this.trimOversizedQueueExtraRoomRatio = trimOversizedQueueExtraRoomRatio; + } + public boolean isPersistenceEnabled() { return persistenceEnabled; } + + public double getTrimOversizedQueueExtraRoomRatio() { + return trimOversizedQueueExtraRoomRatio; + } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 56f4d0621..9c5340993 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -18,11 +18,17 @@ import java.time.Instant; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; public class MessagePersister implements Managed { @@ -30,17 +36,21 @@ public class MessagePersister implements Managed { private final MessagesCache messagesCache; private final MessagesManager messagesManager; private final AccountsManager accountsManager; + private final DynamicConfigurationManager dynamicConfigurationManager; private final Duration persistDelay; private final Thread[] workerThreads; private volatile boolean running; + private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized"); + private final Timer getQueuesTimer = Metrics.timer(name(MessagePersister.class, "getQueues")); private final Timer persistQueueTimer = Metrics.timer(name(MessagePersister.class, "persistQueue")); private final Counter persistQueueExceptionMeter = Metrics.counter( name(MessagePersister.class, "persistQueueException")); - private final Counter oversizedQueueCounter = Metrics.counter(name(MessagePersister.class, "persistQueueOversized")); + private static final Counter trimmedMessageCounter = Metrics.counter(name(MessagePersister.class, "trimmedMessage")); + private static final Counter trimmedMessageBytesCounter = Metrics.counter(name(MessagePersister.class, "trimmedMessageBytes")); private final DistributionSummary queueCountDistributionSummery = DistributionSummary.builder( name(MessagePersister.class, "queueCount")) .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) @@ -54,6 +64,7 @@ public class MessagePersister implements Managed { static final int QUEUE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100; + static final int CACHE_PAGE_SIZE = 100; private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis(); @@ -66,12 +77,12 @@ public class MessagePersister implements Managed { final AccountsManager accountsManager, final DynamicConfigurationManager dynamicConfigurationManager, final Duration persistDelay, - final int dedicatedProcessWorkerThreadCount - ) { + final int dedicatedProcessWorkerThreadCount) { this.messagesCache = messagesCache; this.messagesManager = messagesManager; this.accountsManager = accountsManager; + this.dynamicConfigurationManager = dynamicConfigurationManager; this.persistDelay = persistDelay; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; @@ -159,7 +170,10 @@ public class MessagePersister implements Managed { messagesCache.addQueueToPersist(accountUuid, deviceId); - Util.sleep(EXCEPTION_PAUSE_MILLIS); + if (!(e instanceof MessagePersistenceException)) { + // Pause after unexpected exceptions + Util.sleep(EXCEPTION_PAUSE_MILLIS); + } } } @@ -204,8 +218,19 @@ public class MessagePersister implements Managed { queueSizeDistributionSummery.record(messageCount); } catch (ItemCollectionSizeLimitExceededException e) { - oversizedQueueCounter.increment(); - maybeUnlink(account, deviceId); // may throw, in which case we'll retry later by the usual mechanism + final boolean isPrimary = deviceId == Device.PRIMARY_ID; + Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); + // may throw, in which case we'll retry later by the usual mechanism + if (isPrimary) { + logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages", + account.getUuid(), deviceId); + trimQueue(account, deviceId); + throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed"); + } else { + logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", account.getUuid(), + deviceId); + accountsManager.removeDevice(account, deviceId).join(); + } } finally { messagesCache.unlockQueueForPersistence(accountUuid, deviceId); sample.stop(persistQueueTimer); @@ -213,13 +238,62 @@ public class MessagePersister implements Managed { } - @VisibleForTesting - void maybeUnlink(final Account account, byte destinationDeviceId) throws MessagePersistenceException { - if (destinationDeviceId == Device.PRIMARY_ID) { - throw new MessagePersistenceException("primary device has a full queue"); - } + private void trimQueue(final Account account, byte deviceId) { + final UUID aci = account.getIdentifier(IdentityType.ACI); - logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", account.getUuid(), destinationDeviceId); - accountsManager.removeDevice(account, destinationDeviceId).join(); + final Optional maybeDevice = account.getDevice(deviceId); + if (maybeDevice.isEmpty()) { + logger.warn("Not deleting messages for overfull queue {}::{}, deviceId {} does not exist", + aci, deviceId, deviceId); + return; + } + final Device device = maybeDevice.get(); + + // Calculate how many bytes we should trim + final long cachedMessageBytes = Flux + .from(messagesCache.getMessagesToPersistReactive(aci, deviceId, CACHE_PAGE_SIZE)) + .reduce(0, (acc, envelope) -> acc + envelope.getSerializedSize()) + .block(); + final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration() + .getMessagePersisterConfiguration() + .getTrimOversizedQueueExtraRoomRatio(); + final long targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio); + + final AtomicLong oldestMessage = new AtomicLong(0L); + final AtomicLong newestMessage = new AtomicLong(0L); + final AtomicLong bytesDeleted = new AtomicLong(0L); + + // Iterate from the oldest message until we've removed targetDeleteBytes + final Pair outcomes = Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false)) + .concatMap(envelope -> { + if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes) { + return Mono.just(Optional.empty()); + } + oldestMessage.compareAndSet(0L, envelope.getServerTimestamp()); + newestMessage.set(envelope.getServerTimestamp()); + return Mono.just(Optional.of(envelope)); + }) + .takeWhile(Optional::isPresent) + .flatMap(maybeEnvelope -> { + final MessageProtos.Envelope envelope = maybeEnvelope.get(); + trimmedMessageCounter.increment(); + trimmedMessageBytesCounter.increment(envelope.getSerializedSize()); + return Mono + .fromCompletionStage(() -> messagesManager + .delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp())) + .retryWhen(Retry.backoff(5, Duration.ofSeconds(1))) + .map(Optional::isPresent); + }) + .reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted + ? Pair.of(acc.getLeft() + 1, acc.getRight()) + : Pair.of(acc.getLeft(), acc.getRight() + 1)) + .block(); + + logger.warn( + "Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.", + aci, deviceId, + Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()), + targetDeleteBytes, cachedMessageBytes, + outcomes.getLeft(), outcomes.getRight()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 13b117b9a..e27ea9243 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -10,6 +10,9 @@ import static com.codahale.metrics.MetricRegistry.name; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; +import io.lettuce.core.Limit; +import io.lettuce.core.Range; +import io.lettuce.core.ScoredValue; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.micrometer.core.instrument.Counter; @@ -31,6 +34,7 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.function.Function; import java.util.function.Predicate; import org.reactivestreams.Publisher; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; @@ -528,15 +532,28 @@ public class MessagesCache { }); } - List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, - final int limit) { - + Flux getMessagesToPersistReactive(final UUID accountUuid, final byte destinationDevice, + final int pageSize) { final Timer.Sample sample = Timer.start(); - final List messages = redisCluster.withBinaryCluster(connection -> - connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit)); + final Function>>> getNextPage = (Long start) -> + Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection -> + connection.async().zrangebyscoreWithScores( + getMessageQueueKey(accountUuid, destinationDevice), + Range.from( + Range.Boundary.excluding(start), + Range.Boundary.unbounded()), + Limit.from(pageSize)))); - final Flux allMessages = Flux.fromIterable(messages) + final Flux allMessages = getNextPage.apply(0L) + .expand(scoredValues -> { + if (scoredValues.isEmpty()) { + return Mono.empty(); + } + long lastTimestamp = (long) scoredValues.getLast().getScore(); + return getNextPage.apply(lastTimestamp); + }) + .concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))) .mapNotNull(message -> { try { return MessageProtos.Envelope.parseFrom(message); @@ -554,7 +571,15 @@ public class MessagesCache { } return messageMono; - }); + }) + .publish() + // We expect exactly three subscribers to this base flux: + // 1. the caller of the method + // 2. an internal processes to discard stale ephemeral messages + // 3. an internal process to discard stale MRM messages + // The discard subscribers will subscribe immediately, but we don’t want to do any work if the + // caller never subscribes + .autoConnect(3); final Flux messagesToPersist = allMessages .filter(Predicate.not(envelope -> @@ -570,8 +595,14 @@ public class MessagesCache { discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); return messagesToPersist + .doOnTerminate(() -> sample.stop(getMessagesTimer)); + } + + List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, + final int limit) { + return getMessagesToPersistReactive(accountUuid, destinationDevice, limit) + .take(limit) .collectList() - .doOnTerminate(() -> sample.stop(getMessagesTimer)) .block(Duration.ofSeconds(5)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index 3f0992c28..5a9bf8937 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -11,7 +11,6 @@ import static io.micrometer.core.instrument.Metrics.timer; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.protobuf.InvalidProtocolBufferException; - import io.micrometer.core.instrument.Timer; import java.nio.ByteBuffer; import java.time.Duration; @@ -24,7 +23,6 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.function.Predicate; - import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 58a7e6417..afff5a009 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -10,12 +10,16 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNotNull; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.util.MockUtils.exactly; @@ -26,6 +30,7 @@ import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -34,8 +39,10 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -45,9 +52,12 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; +import reactor.core.publisher.Flux; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; @@ -75,6 +85,8 @@ class MessagePersisterTest { private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); + private static final double EXTRA_ROOM_RATIO = 2.0; + @BeforeEach void setUp() throws Exception { @@ -84,16 +96,21 @@ class MessagePersisterTest { messagesDynamoDb = mock(MessagesDynamoDb.class); accountsManager = mock(AccountsManager.class); - destinationAccount = mock(Account.class);; + destinationAccount = mock(Account.class); when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount)); when(accountsManager.removeDevice(any(), anyByte())) .thenAnswer(invocation -> CompletableFuture.completedFuture(invocation.getArgument(0))); when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID); + when(destinationAccount.getIdentifier(IdentityType.ACI)).thenReturn(DESTINATION_ACCOUNT_UUID); when(destinationAccount.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); when(destinationAccount.getDevice(DESTINATION_DEVICE_ID)).thenReturn(Optional.of(DESTINATION_DEVICE)); - when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); + when(dynamicConfiguration.getMessagePersisterConfiguration()) + .thenReturn(new DynamicMessagePersisterConfiguration(true, EXTRA_ROOM_RATIO)); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); sharedExecutorService = Executors.newSingleThreadExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); @@ -285,6 +302,66 @@ class MessagePersisterTest { verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); } + @Test + void testTrimOnFullPrimaryQueue() { + final byte[] queueName = MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, Device.PRIMARY_ID); + final Instant now = Instant.now(); + + final List cachedMessages = Stream.generate(() -> generateMessage( + DESTINATION_ACCOUNT_UUID, UUID.randomUUID(), now.getEpochSecond(), ThreadLocalRandom.current().nextInt(100))) + .limit(10) + .toList(); + final long cacheSize = cachedMessages.stream().mapToLong(MessageProtos.Envelope::getSerializedSize).sum(); + for (MessageProtos.Envelope envelope : cachedMessages) { + messagesCache.insert(UUID.fromString(envelope.getServerGuid()), DESTINATION_ACCOUNT_UUID, Device.PRIMARY_ID, envelope); + } + + final long expectedClearedBytes = (long) (cacheSize * EXTRA_ROOM_RATIO); + + final int persistedMessageCount = 100; + final List persistedMessages = new ArrayList<>(persistedMessageCount); + final List expectedClearedGuids = new ArrayList<>(); + long total = 0L; + for (int i = 0; i < 100; i++) { + final UUID guid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateMessage(DESTINATION_ACCOUNT_UUID, guid, now.getEpochSecond(), 13); + persistedMessages.add(envelope); + if (total < expectedClearedBytes) { + total += envelope.getSerializedSize(); + expectedClearedGuids.add(guid); + } + } + + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + final Device primary = mock(Device.class); + when(primary.getId()).thenReturn((byte) 1); + when(primary.isPrimary()).thenReturn(true); + when(primary.getFetchesMessages()).thenReturn(true); + when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primary)); + + when(messagesManager.persistMessages(any(UUID.class), any(), anyList())) + .thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.getMessagesForDeviceReactive(DESTINATION_ACCOUNT_UUID, primary, false)) + .thenReturn(Flux.concat( + Flux.fromIterable(persistedMessages), + Flux.fromIterable(cachedMessages))); + when(messagesManager.delete(any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + assertTimeoutPreemptively(Duration.ofSeconds(10), () -> + messagePersister.persistNextQueues(Clock.systemUTC().instant())); + + verify(messagesManager, times(expectedClearedGuids.size())) + .delete(eq(DESTINATION_ACCOUNT_UUID), eq(primary), argThat(expectedClearedGuids::contains), isNotNull()); + verify(messagesManager, never()).delete(any(), any(), argThat(guid -> !expectedClearedGuids.contains(guid)), any()); + + final List queuesToPersist = messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), + Clock.systemUTC().instant(), 1); + assertEquals(queuesToPersist.size(), 1); + assertEquals(queuesToPersist.getFirst(), new String(queueName, StandardCharsets.UTF_8)); + } + @Test void testFailedUnlinkOnFullQueueThrowsForRetry() { final String queueName = new String( @@ -348,20 +425,23 @@ class MessagePersisterTest { final Instant firstMessageTimestamp) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); - - final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() - .setDestinationServiceId(accountUuid.toString()) - .setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i) - .setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i) - .setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256))) - .setType(MessageProtos.Envelope.Type.CIPHERTEXT) - .setServerGuid(messageGuid.toString()) - .build(); - + final MessageProtos.Envelope envelope = generateMessage( + accountUuid, messageGuid, firstMessageTimestamp.toEpochMilli() + i, 256); messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join(); } } + private MessageProtos.Envelope generateMessage(UUID accountUuid, UUID messageGuid, long messageTimestamp, int contentSize) { + return MessageProtos.Envelope.newBuilder() + .setDestinationServiceId(accountUuid.toString()) + .setClientTimestamp(messageTimestamp) + .setServerTimestamp(messageTimestamp) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(contentSize))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(messageGuid.toString()) + .build(); + } + private void setNextSlotToPersist(final int nextSlot) { REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster( connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 06f7f36ec..18e3971d9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -54,6 +55,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; @@ -488,6 +490,46 @@ class MessagesCacheTest { }, "Shared MRM data should be deleted asynchronously"); } + @Test + void testMessagesToPersistReactive() { + final UUID destinationUuid = UUID.randomUUID(); + final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); + final byte deviceId = 1; + + final List expected = IntStream.range(0, 100) + .mapToObj(i -> { + if (i % 3 == 0) { + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId); + byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join(); + return generateRandomMessage(UUID.randomUUID(), serviceId, true) + .toBuilder() + // clear some things added by the helper + .clearContent() + .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) + .build(); + } else if (i % 13 == 0) { + return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build(); + } else { + return generateRandomMessage(UUID.randomUUID(), serviceId, true); + } + }) + .filter(envelope -> !envelope.getEphemeral()) + .toList(); + + for (MessageProtos.Envelope envelope : expected) { + messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); + } + + final List actual = messagesCache + .getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block(); + + assertEquals(expected.size(), actual.size()); + for (int i = 0; i < actual.size(); i++) { + assertNotNull(actual.get(i).getContent()); + assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid()); + } + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) {