diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index e27ea9243..87f329d0d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -50,6 +50,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil; import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; @@ -149,7 +150,8 @@ public class MessagesCache { static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); - private static final ByteString STALE_MRM_KEY = ByteString.copyFromUtf8("stale"); + @VisibleForTesting + static final ByteString STALE_MRM_KEY = ByteString.copyFromUtf8("stale"); @VisibleForTesting static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); @@ -536,22 +538,26 @@ public class MessagesCache { final int pageSize) { final Timer.Sample sample = Timer.start(); - final Function>>> getNextPage = (Long start) -> + + final Function, Mono>>> getNextPage = (Optional start) -> Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection -> connection.async().zrangebyscoreWithScores( getMessageQueueKey(accountUuid, destinationDevice), Range.from( - Range.Boundary.excluding(start), + start.map(Range.Boundary::excluding).orElse(Range.Boundary.unbounded()), Range.Boundary.unbounded()), Limit.from(pageSize)))); - final Flux allMessages = getNextPage.apply(0L) + final Sinks.Many staleEphemeralMessages = Sinks.many().unicast().onBackpressureBuffer(); + final Sinks.Many staleMrmMessages = Sinks.many().unicast().onBackpressureBuffer(); + + final Flux messagesToPersist = getNextPage.apply(Optional.empty()) .expand(scoredValues -> { if (scoredValues.isEmpty()) { return Mono.empty(); } long lastTimestamp = (long) scoredValues.getLast().getScore(); - return getNextPage.apply(lastTimestamp); + return getNextPage.apply(Optional.of(lastTimestamp)); }) .concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))) .mapNotNull(message -> { @@ -572,30 +578,25 @@ public class MessagesCache { return messageMono; }) - .publish() - // We expect exactly three subscribers to this base flux: - // 1. the caller of the method - // 2. an internal processes to discard stale ephemeral messages - // 3. an internal process to discard stale MRM messages - // The discard subscribers will subscribe immediately, but we don’t want to do any work if the - // caller never subscribes - .autoConnect(3); + .doOnNext(envelope -> { + if (envelope.getEphemeral()) { + staleEphemeralMessages.tryEmitNext(envelope).orThrow(); + } else if (isStaleMrmMessage(envelope)) { + // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data + staleMrmMessages.tryEmitNext(envelope.toBuilder().clearSharedMrmKey().build()).orThrow(); + } + }) + .filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope))); - final Flux messagesToPersist = allMessages - .filter(Predicate.not(envelope -> - envelope.getEphemeral() || isStaleMrmMessage(envelope))); - - final Flux ephemeralMessages = allMessages - .filter(MessageProtos.Envelope::getEphemeral); - discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); - - final Flux staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage) - // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data - .map(envelope -> envelope.toBuilder().clearSharedMrmKey().build()); - discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); + discardStaleMessages(accountUuid, destinationDevice, staleEphemeralMessages.asFlux(), staleEphemeralMessagesCounter, "ephemeral"); + discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages.asFlux(), staleMrmMessagesCounter, "mrm"); return messagesToPersist - .doOnTerminate(() -> sample.stop(getMessagesTimer)); + .doFinally(signal -> { + sample.stop(getMessagesTimer); + staleEphemeralMessages.tryEmitComplete(); + staleMrmMessages.tryEmitComplete(); + }); } List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 18e3971d9..bc9a0264b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; @@ -490,13 +491,70 @@ class MessagesCacheTest { }, "Shared MRM data should be deleted asynchronously"); } + @Test + void testMessagesToPersistPagination() throws InterruptedException { + final UUID destinationUuid = UUID.randomUUID(); + final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); + final byte deviceId = 1; + final byte[] queueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId); + + final List messages = IntStream.range(0, 60) + .mapToObj(i -> switch (i % 3) { + // Stale MRM + case 0 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) + .toBuilder() + // clear some things added by the helper + .clearContent() + .setSharedMrmKey(MessagesCache.STALE_MRM_KEY) + .build(); + // ephemeral message + case 1 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) + .toBuilder() + .setEphemeral(true).build(); + // standard message + case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, true); + default -> throw new IllegalStateException(); + }) + .toList(); + for (MessageProtos.Envelope envelope : messages) { + messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); + } + + final List expectedGuidsToPersist = messages.stream() + .filter(envelope -> !envelope.getEphemeral() && !envelope.hasSharedMrmKey()) + .map(envelope -> UUID.fromString(envelope.getServerGuid())) + .limit(10) + .collect(Collectors.toList()); + + // Fetch 10 messages which should discard 20 ephemeral stale messages, and leave the rest + final List actual = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 10).stream() + .map(envelope -> UUID.fromString(envelope.getServerGuid())) + .toList(); + assertIterableEquals(expectedGuidsToPersist, actual); + + // Eventually, the 20 ephemeral/stale messages should be discarded + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { + while (REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().zcard(queueKey)) != 40) { + Thread.sleep(1); + } + }, "Ephemeral and stale messages should be deleted asynchronously"); + + // Let all pending tasks finish and make sure no more stale messages have been deleted + sharedExecutorService.shutdown(); + sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS); + assertEquals(REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().zcard(queueKey)).longValue(), 40); + } + @Test void testMessagesToPersistReactive() { final UUID destinationUuid = UUID.randomUUID(); final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); final byte deviceId = 1; + final byte[] messageQueueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId); - final List expected = IntStream.range(0, 100) + final List messages = IntStream.range(0, 200) .mapToObj(i -> { if (i % 3 == 0) { final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId); @@ -509,17 +567,27 @@ class MessagesCacheTest { .build(); } else if (i % 13 == 0) { return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build(); + } else if (i % 17 == 0) { + return generateRandomMessage(UUID.randomUUID(), serviceId, true) + .toBuilder() + // clear some things added by the helper + .clearContent() + .setSharedMrmKey(MessagesCache.STALE_MRM_KEY) + .build(); } else { return generateRandomMessage(UUID.randomUUID(), serviceId, true); } }) - .filter(envelope -> !envelope.getEphemeral()) .toList(); - for (MessageProtos.Envelope envelope : expected) { + for (MessageProtos.Envelope envelope : messages) { messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); } + final List expected = messages.stream() + .filter(envelope -> !envelope.getEphemeral() && + (envelope.getSharedMrmKey() == null || !envelope.getSharedMrmKey().equals(MessagesCache.STALE_MRM_KEY))) + .toList(); final List actual = messagesCache .getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block(); @@ -528,6 +596,14 @@ class MessagesCacheTest { assertNotNull(actual.get(i).getContent()); assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid()); } + + // Ephemeral messages and stale MRM messages are asynchronously deleted, but eventually they should all be removed + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { + while (REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> conn.sync().zcard(messageQueueKey)) != expected.size()) { + Thread.sleep(1); + } + }, "Ephemeral and stale messages should be deleted asynchronously"); } @ParameterizedTest