Dont discard ephemeral messages beyond what the persister consumes

This commit is contained in:
ravi-signal 2025-03-07 15:27:03 -06:00 committed by GitHub
parent b7fee7b426
commit eab3c36d83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 106 additions and 29 deletions

View File

@ -50,6 +50,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import reactor.core.observability.micrometer.Micrometer; import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
@ -149,7 +150,8 @@ public class MessagesCache {
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
private static final ByteString STALE_MRM_KEY = ByteString.copyFromUtf8("stale"); @VisibleForTesting
static final ByteString STALE_MRM_KEY = ByteString.copyFromUtf8("stale");
@VisibleForTesting @VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
@ -536,22 +538,26 @@ public class MessagesCache {
final int pageSize) { final int pageSize) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final Function<Long, Mono<List<ScoredValue<byte[]>>>> getNextPage = (Long start) ->
final Function<Optional<Long>, Mono<List<ScoredValue<byte[]>>>> getNextPage = (Optional<Long> start) ->
Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection -> Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection ->
connection.async().zrangebyscoreWithScores( connection.async().zrangebyscoreWithScores(
getMessageQueueKey(accountUuid, destinationDevice), getMessageQueueKey(accountUuid, destinationDevice),
Range.from( Range.from(
Range.Boundary.excluding(start), start.map(Range.Boundary::excluding).orElse(Range.Boundary.unbounded()),
Range.Boundary.unbounded()), Range.Boundary.unbounded()),
Limit.from(pageSize)))); Limit.from(pageSize))));
final Flux<MessageProtos.Envelope> allMessages = getNextPage.apply(0L) final Sinks.Many<MessageProtos.Envelope> staleEphemeralMessages = Sinks.many().unicast().onBackpressureBuffer();
final Sinks.Many<MessageProtos.Envelope> staleMrmMessages = Sinks.many().unicast().onBackpressureBuffer();
final Flux<MessageProtos.Envelope> messagesToPersist = getNextPage.apply(Optional.empty())
.expand(scoredValues -> { .expand(scoredValues -> {
if (scoredValues.isEmpty()) { if (scoredValues.isEmpty()) {
return Mono.empty(); return Mono.empty();
} }
long lastTimestamp = (long) scoredValues.getLast().getScore(); long lastTimestamp = (long) scoredValues.getLast().getScore();
return getNextPage.apply(lastTimestamp); return getNextPage.apply(Optional.of(lastTimestamp));
}) })
.concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))) .concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue)))
.mapNotNull(message -> { .mapNotNull(message -> {
@ -572,30 +578,25 @@ public class MessagesCache {
return messageMono; return messageMono;
}) })
.publish() .doOnNext(envelope -> {
// We expect exactly three subscribers to this base flux: if (envelope.getEphemeral()) {
// 1. the caller of the method staleEphemeralMessages.tryEmitNext(envelope).orThrow();
// 2. an internal processes to discard stale ephemeral messages } else if (isStaleMrmMessage(envelope)) {
// 3. an internal process to discard stale MRM messages // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
// The discard subscribers will subscribe immediately, but we dont want to do any work if the staleMrmMessages.tryEmitNext(envelope.toBuilder().clearSharedMrmKey().build()).orThrow();
// caller never subscribes }
.autoConnect(3); })
.filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope)));
final Flux<MessageProtos.Envelope> messagesToPersist = allMessages discardStaleMessages(accountUuid, destinationDevice, staleEphemeralMessages.asFlux(), staleEphemeralMessagesCounter, "ephemeral");
.filter(Predicate.not(envelope -> discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages.asFlux(), staleMrmMessagesCounter, "mrm");
envelope.getEphemeral() || isStaleMrmMessage(envelope)));
final Flux<MessageProtos.Envelope> ephemeralMessages = allMessages
.filter(MessageProtos.Envelope::getEphemeral);
discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral");
final Flux<MessageProtos.Envelope> staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage)
// clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
.map(envelope -> envelope.toBuilder().clearSharedMrmKey().build());
discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm");
return messagesToPersist return messagesToPersist
.doOnTerminate(() -> sample.stop(getMessagesTimer)); .doFinally(signal -> {
sample.stop(getMessagesTimer);
staleEphemeralMessages.tryEmitComplete();
staleMrmMessages.tryEmitComplete();
});
} }
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,

View File

@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@ -490,13 +491,70 @@ class MessagesCacheTest {
}, "Shared MRM data should be deleted asynchronously"); }, "Shared MRM data should be deleted asynchronously");
} }
@Test
void testMessagesToPersistPagination() throws InterruptedException {
final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1;
final byte[] queueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId);
final List<MessageProtos.Envelope> messages = IntStream.range(0, 60)
.mapToObj(i -> switch (i % 3) {
// Stale MRM
case 0 -> generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
// clear some things added by the helper
.clearContent()
.setSharedMrmKey(MessagesCache.STALE_MRM_KEY)
.build();
// ephemeral message
case 1 -> generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
.setEphemeral(true).build();
// standard message
case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, true);
default -> throw new IllegalStateException();
})
.toList();
for (MessageProtos.Envelope envelope : messages) {
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join();
}
final List<UUID> expectedGuidsToPersist = messages.stream()
.filter(envelope -> !envelope.getEphemeral() && !envelope.hasSharedMrmKey())
.map(envelope -> UUID.fromString(envelope.getServerGuid()))
.limit(10)
.collect(Collectors.toList());
// Fetch 10 messages which should discard 20 ephemeral stale messages, and leave the rest
final List<UUID> actual = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 10).stream()
.map(envelope -> UUID.fromString(envelope.getServerGuid()))
.toList();
assertIterableEquals(expectedGuidsToPersist, actual);
// Eventually, the 20 ephemeral/stale messages should be discarded
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
while (REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(queueKey)) != 40) {
Thread.sleep(1);
}
}, "Ephemeral and stale messages should be deleted asynchronously");
// Let all pending tasks finish and make sure no more stale messages have been deleted
sharedExecutorService.shutdown();
sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS);
assertEquals(REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(queueKey)).longValue(), 40);
}
@Test @Test
void testMessagesToPersistReactive() { void testMessagesToPersistReactive() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1; final byte deviceId = 1;
final byte[] messageQueueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId);
final List<MessageProtos.Envelope> expected = IntStream.range(0, 100) final List<MessageProtos.Envelope> messages = IntStream.range(0, 200)
.mapToObj(i -> { .mapToObj(i -> {
if (i % 3 == 0) { if (i % 3 == 0) {
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId); final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId);
@ -509,17 +567,27 @@ class MessagesCacheTest {
.build(); .build();
} else if (i % 13 == 0) { } else if (i % 13 == 0) {
return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build(); return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build();
} else if (i % 17 == 0) {
return generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
// clear some things added by the helper
.clearContent()
.setSharedMrmKey(MessagesCache.STALE_MRM_KEY)
.build();
} else { } else {
return generateRandomMessage(UUID.randomUUID(), serviceId, true); return generateRandomMessage(UUID.randomUUID(), serviceId, true);
} }
}) })
.filter(envelope -> !envelope.getEphemeral())
.toList(); .toList();
for (MessageProtos.Envelope envelope : expected) { for (MessageProtos.Envelope envelope : messages) {
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join(); messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join();
} }
final List<MessageProtos.Envelope> expected = messages.stream()
.filter(envelope -> !envelope.getEphemeral() &&
(envelope.getSharedMrmKey() == null || !envelope.getSharedMrmKey().equals(MessagesCache.STALE_MRM_KEY)))
.toList();
final List<MessageProtos.Envelope> actual = messagesCache final List<MessageProtos.Envelope> actual = messagesCache
.getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block(); .getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block();
@ -528,6 +596,14 @@ class MessagesCacheTest {
assertNotNull(actual.get(i).getContent()); assertNotNull(actual.get(i).getContent());
assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid()); assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid());
} }
// Ephemeral messages and stale MRM messages are asynchronously deleted, but eventually they should all be removed
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
while (REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(messageQueueKey)) != expected.size()) {
Thread.sleep(1);
}
}, "Ephemeral and stale messages should be deleted asynchronously");
} }
@ParameterizedTest @ParameterizedTest