From e3160bc717fa1d033009a4757caeab408e0c2a7d Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Mon, 10 Mar 2025 16:09:05 -0500 Subject: [PATCH] Add a dedicated size estimation method to MessagesCache --- .../storage/MessagePersister.java | 5 +- .../textsecuregcm/storage/MessagesCache.java | 88 +++++----- .../storage/MessagesCacheTest.java | 150 ++++++------------ 3 files changed, 97 insertions(+), 146 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 9c5340993..115ea5d9c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -250,10 +250,7 @@ public class MessagePersister implements Managed { final Device device = maybeDevice.get(); // Calculate how many bytes we should trim - final long cachedMessageBytes = Flux - .from(messagesCache.getMessagesToPersistReactive(aci, deviceId, CACHE_PAGE_SIZE)) - .reduce(0, (acc, envelope) -> acc + envelope.getSerializedSize()) - .block(); + final long cachedMessageBytes = messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId).join(); final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration() .getMessagePersisterConfiguration() .getTrimOversizedQueueExtraRoomRatio(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 87f329d0d..b446fd3de 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -50,7 +50,6 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil; import 
reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; @@ -534,11 +533,13 @@ public class MessagesCache { }); } - Flux getMessagesToPersistReactive(final UUID accountUuid, final byte destinationDevice, - final int pageSize) { - final Timer.Sample sample = Timer.start(); - - + /** + * Estimate the size of the cached queue if it were to be persisted + * @param accountUuid The account identifier + * @param destinationDevice The destination device id + * @return A future that completes with the approximate size of stored messages that need to be persisted + */ + CompletableFuture estimatePersistedQueueSizeBytes(final UUID accountUuid, final byte destinationDevice) { final Function, Mono>>> getNextPage = (Optional start) -> Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection -> connection.async().zrangebyscoreWithScores( @@ -546,12 +547,8 @@ public class MessagesCache { Range.from( start.map(Range.Boundary::excluding).orElse(Range.Boundary.unbounded()), Range.Boundary.unbounded()), - Limit.from(pageSize)))); - - final Sinks.Many staleEphemeralMessages = Sinks.many().unicast().onBackpressureBuffer(); - final Sinks.Many staleMrmMessages = Sinks.many().unicast().onBackpressureBuffer(); - - final Flux messagesToPersist = getNextPage.apply(Optional.empty()) + Limit.from(PAGE_SIZE)))); + final Flux allSerializedMessages = getNextPage.apply(Optional.empty()) .expand(scoredValues -> { if (scoredValues.isEmpty()) { return Mono.empty(); @@ -559,7 +556,45 @@ public class MessagesCache { long lastTimestamp = (long) scoredValues.getLast().getScore(); return getNextPage.apply(Optional.of(lastTimestamp)); }) - .concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))) + .concatMap(scoredValues -> 
Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))); + + return parseAndFetchMrms(allSerializedMessages, destinationDevice) + .filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope))) + .reduce(0L, (acc, envelope) -> acc + envelope.getSerializedSize()) + .toFuture(); + } + + List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, + final int limit) { + + final Timer.Sample sample = Timer.start(); + + final List messages = redisCluster.withBinaryCluster(connection -> + connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit)); + + final Flux allMessages = parseAndFetchMrms(Flux.fromIterable(messages), destinationDevice); + + final Flux messagesToPersist = allMessages + .filter(Predicate.not(envelope -> + envelope.getEphemeral() || isStaleMrmMessage(envelope))); + + final Flux ephemeralMessages = allMessages + .filter(MessageProtos.Envelope::getEphemeral); + discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); + + final Flux staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage) + // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data + .map(envelope -> envelope.toBuilder().clearSharedMrmKey().build()); + discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); + + return messagesToPersist + .collectList() + .doOnTerminate(() -> sample.stop(getMessagesTimer)) + .block(Duration.ofSeconds(5)); + } + + private Flux parseAndFetchMrms(final Flux serializedMessages, final byte destinationDevice) { + return serializedMessages .mapNotNull(message -> { try { return MessageProtos.Envelope.parseFrom(message); @@ -577,36 +612,9 @@ public class MessagesCache { } return messageMono; - }) - .doOnNext(envelope -> { - if (envelope.getEphemeral()) { - staleEphemeralMessages.tryEmitNext(envelope).orThrow(); - } else if 
(isStaleMrmMessage(envelope)) { - // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data - staleMrmMessages.tryEmitNext(envelope.toBuilder().clearSharedMrmKey().build()).orThrow(); - } - }) - .filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope))); - - discardStaleMessages(accountUuid, destinationDevice, staleEphemeralMessages.asFlux(), staleEphemeralMessagesCounter, "ephemeral"); - discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages.asFlux(), staleMrmMessagesCounter, "mrm"); - - return messagesToPersist - .doFinally(signal -> { - sample.stop(getMessagesTimer); - staleEphemeralMessages.tryEmitComplete(); - staleMrmMessages.tryEmitComplete(); }); } - List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, - final int limit) { - return getMessagesToPersistReactive(accountUuid, destinationDevice, limit) - .take(limit) - .collectList() - .block(Duration.ofSeconds(5)); - } - public CompletableFuture clear(final UUID destinationUuid) { return CompletableFuture.allOf( Device.ALL_POSSIBLE_DEVICE_IDS.stream() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index bc9a0264b..f03795b47 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -492,118 +492,64 @@ class MessagesCacheTest { } @Test - void testMessagesToPersistPagination() throws InterruptedException { + void testEstimatePersistedQueueSize() { final UUID destinationUuid = UUID.randomUUID(); final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); final byte deviceId = 1; - final byte[] queueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId); - final List messages = IntStream.range(0, 60) - .mapToObj(i 
-> switch (i % 3) { - // Stale MRM - case 0 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) + // Should count all non-ephemeral, non-stale message bytes + long expectedQueueSize = 0L; + for (int i = 0; i < 400; i++) { + final MessageProtos.Envelope messageToInsert = switch (i % 4) { + // An MRM message + case 0 -> { + + // First generate a random MRM message + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId); + final SealedSenderMultiRecipientMessage.Recipient recepient = mrm.getRecipients() + .get(serviceId.toLibsignal()); + + // Calculate the size of a message that has the shared content in it + final MessageProtos.Envelope message = generateRandomMessage(UUID.randomUUID(), serviceId, true) + .toBuilder() + .setContent(ByteString.copyFrom(mrm.messageForRecipient(recepient))) + .build(); + expectedQueueSize += message.getSerializedSize(); + byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join(); + + // Insert the MRM message without the content + yield message + .toBuilder() + .clearContent() + .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) + .build(); + } + + // A stale MRM message + case 1 -> + generateRandomMessage(UUID.randomUUID(), serviceId, true) .toBuilder() // clear some things added by the helper .clearContent() .setSharedMrmKey(MessagesCache.STALE_MRM_KEY) .build(); - // ephemeral message - case 1 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) - .toBuilder() - .setEphemeral(true).build(); - // standard message - case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, true); - default -> throw new IllegalStateException(); - }) - .toList(); - for (MessageProtos.Envelope envelope : messages) { - messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); + + // An ephemeral message + case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, 
true).toBuilder().setEphemeral(true).build(); + + // A standard message + case 3 -> { + final MessageProtos.Envelope message = generateRandomMessage(UUID.randomUUID(), serviceId, true); + expectedQueueSize += message.getSerializedSize(); + yield message; + } + + default -> throw new IllegalStateException(); + }; + messagesCache.insert(UUID.fromString(messageToInsert.getServerGuid()), destinationUuid, deviceId, messageToInsert).join(); } - - final List expectedGuidsToPersist = messages.stream() - .filter(envelope -> !envelope.getEphemeral() && !envelope.hasSharedMrmKey()) - .map(envelope -> UUID.fromString(envelope.getServerGuid())) - .limit(10) - .collect(Collectors.toList()); - - // Fetch 10 messages which should discard 20 ephemeral stale messages, and leave the rest - final List actual = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 10).stream() - .map(envelope -> UUID.fromString(envelope.getServerGuid())) - .toList(); - assertIterableEquals(expectedGuidsToPersist, actual); - - // Eventually, the 20 ephemeral/stale messages should be discarded - assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { - while (REDIS_CLUSTER_EXTENSION.getRedisCluster() - .withBinaryCluster(conn -> conn.sync().zcard(queueKey)) != 40) { - Thread.sleep(1); - } - }, "Ephemeral and stale messages should be deleted asynchronously"); - - // Let all pending tasks finish and make sure no more stale messages have been deleted - sharedExecutorService.shutdown(); - sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS); - assertEquals(REDIS_CLUSTER_EXTENSION.getRedisCluster() - .withBinaryCluster(conn -> conn.sync().zcard(queueKey)).longValue(), 40); - } - - @Test - void testMessagesToPersistReactive() { - final UUID destinationUuid = UUID.randomUUID(); - final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); - final byte deviceId = 1; - final byte[] messageQueueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId); - - final List
messages = IntStream.range(0, 200) - .mapToObj(i -> { - if (i % 3 == 0) { - final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId); - byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join(); - return generateRandomMessage(UUID.randomUUID(), serviceId, true) - .toBuilder() - // clear some things added by the helper - .clearContent() - .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) - .build(); - } else if (i % 13 == 0) { - return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build(); - } else if (i % 17 == 0) { - return generateRandomMessage(UUID.randomUUID(), serviceId, true) - .toBuilder() - // clear some things added by the helper - .clearContent() - .setSharedMrmKey(MessagesCache.STALE_MRM_KEY) - .build(); - } else { - return generateRandomMessage(UUID.randomUUID(), serviceId, true); - } - }) - .toList(); - - for (MessageProtos.Envelope envelope : messages) { - messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); - } - - final List expected = messages.stream() - .filter(envelope -> !envelope.getEphemeral() && - (envelope.getSharedMrmKey() == null || !envelope.getSharedMrmKey().equals(MessagesCache.STALE_MRM_KEY))) - .toList(); - final List actual = messagesCache - .getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block(); - - assertEquals(expected.size(), actual.size()); - for (int i = 0; i < actual.size(); i++) { - assertNotNull(actual.get(i).getContent()); - assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid()); - } - - // Ephemeral messages and stale MRM messages are asynchronously deleted, but eventually they should all be removed - assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { - while (REDIS_CLUSTER_EXTENSION.getRedisCluster() - .withBinaryCluster(conn -> conn.sync().zcard(messageQueueKey)) != expected.size()) { - 
Thread.sleep(1); - } - }, "Ephemeral and stale messages should be deleted asynchronously"); + long actualQueueSize = messagesCache.estimatePersistedQueueSizeBytes(destinationUuid, deviceId).join(); + assertEquals(expectedQueueSize, actualQueueSize); } @ParameterizedTest