Add a dedicated size estimation method to MessagesCache

This commit is contained in:
ravi-signal 2025-03-10 16:09:05 -05:00 committed by GitHub
parent 6798958650
commit e3160bc717
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 146 deletions

View File

@ -250,10 +250,7 @@ public class MessagePersister implements Managed {
final Device device = maybeDevice.get(); final Device device = maybeDevice.get();
// Calculate how many bytes we should trim // Calculate how many bytes we should trim
final long cachedMessageBytes = Flux final long cachedMessageBytes = messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId).join();
.from(messagesCache.getMessagesToPersistReactive(aci, deviceId, CACHE_PAGE_SIZE))
.reduce(0, (acc, envelope) -> acc + envelope.getSerializedSize())
.block();
final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration() final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration()
.getMessagePersisterConfiguration() .getMessagePersisterConfiguration()
.getTrimOversizedQueueExtraRoomRatio(); .getTrimOversizedQueueExtraRoomRatio();

View File

@ -50,7 +50,6 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import reactor.core.observability.micrometer.Micrometer; import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
@ -534,11 +533,13 @@ public class MessagesCache {
}); });
} }
Flux<MessageProtos.Envelope> getMessagesToPersistReactive(final UUID accountUuid, final byte destinationDevice, /**
final int pageSize) { * Estimate the size of the cached queue if it were to be persisted
final Timer.Sample sample = Timer.start(); * @param accountUuid The account identifier
* @param destinationDevice The destination device id
* @return A future that completes with the approximate size of stored messages that need to be persisted
*/
CompletableFuture<Long> estimatePersistedQueueSizeBytes(final UUID accountUuid, final byte destinationDevice) {
final Function<Optional<Long>, Mono<List<ScoredValue<byte[]>>>> getNextPage = (Optional<Long> start) -> final Function<Optional<Long>, Mono<List<ScoredValue<byte[]>>>> getNextPage = (Optional<Long> start) ->
Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection -> Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection ->
connection.async().zrangebyscoreWithScores( connection.async().zrangebyscoreWithScores(
@ -546,12 +547,8 @@ public class MessagesCache {
Range.from( Range.from(
start.map(Range.Boundary::excluding).orElse(Range.Boundary.unbounded()), start.map(Range.Boundary::excluding).orElse(Range.Boundary.unbounded()),
Range.Boundary.unbounded()), Range.Boundary.unbounded()),
Limit.from(pageSize)))); Limit.from(PAGE_SIZE))));
final Flux<byte[]> allSerializedMessages = getNextPage.apply(Optional.empty())
final Sinks.Many<MessageProtos.Envelope> staleEphemeralMessages = Sinks.many().unicast().onBackpressureBuffer();
final Sinks.Many<MessageProtos.Envelope> staleMrmMessages = Sinks.many().unicast().onBackpressureBuffer();
final Flux<MessageProtos.Envelope> messagesToPersist = getNextPage.apply(Optional.empty())
.expand(scoredValues -> { .expand(scoredValues -> {
if (scoredValues.isEmpty()) { if (scoredValues.isEmpty()) {
return Mono.empty(); return Mono.empty();
@ -559,7 +556,45 @@ public class MessagesCache {
long lastTimestamp = (long) scoredValues.getLast().getScore(); long lastTimestamp = (long) scoredValues.getLast().getScore();
return getNextPage.apply(Optional.of(lastTimestamp)); return getNextPage.apply(Optional.of(lastTimestamp));
}) })
.concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue))) .concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue)));
return parseAndFetchMrms(allSerializedMessages, destinationDevice)
.filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope)))
.reduce(0L, (acc, envelope) -> acc + envelope.getSerializedSize())
.toFuture();
}
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
final Timer.Sample sample = Timer.start();
final List<byte[]> messages = redisCluster.withBinaryCluster(connection ->
connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit));
final Flux<MessageProtos.Envelope> allMessages = parseAndFetchMrms(Flux.fromIterable(messages), destinationDevice);
final Flux<MessageProtos.Envelope> messagesToPersist = allMessages
.filter(Predicate.not(envelope ->
envelope.getEphemeral() || isStaleMrmMessage(envelope)));
final Flux<MessageProtos.Envelope> ephemeralMessages = allMessages
.filter(MessageProtos.Envelope::getEphemeral);
discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral");
final Flux<MessageProtos.Envelope> staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage)
// clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
.map(envelope -> envelope.toBuilder().clearSharedMrmKey().build());
discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm");
return messagesToPersist
.collectList()
.doOnTerminate(() -> sample.stop(getMessagesTimer))
.block(Duration.ofSeconds(5));
}
private Flux<MessageProtos.Envelope> parseAndFetchMrms(final Flux<byte[]> serializedMessages, final byte destinationDevice) {
return serializedMessages
.mapNotNull(message -> { .mapNotNull(message -> {
try { try {
return MessageProtos.Envelope.parseFrom(message); return MessageProtos.Envelope.parseFrom(message);
@ -577,36 +612,9 @@ public class MessagesCache {
} }
return messageMono; return messageMono;
})
.doOnNext(envelope -> {
if (envelope.getEphemeral()) {
staleEphemeralMessages.tryEmitNext(envelope).orThrow();
} else if (isStaleMrmMessage(envelope)) {
// clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
staleMrmMessages.tryEmitNext(envelope.toBuilder().clearSharedMrmKey().build()).orThrow();
}
})
.filter(Predicate.not(envelope -> envelope.getEphemeral() || isStaleMrmMessage(envelope)));
discardStaleMessages(accountUuid, destinationDevice, staleEphemeralMessages.asFlux(), staleEphemeralMessagesCounter, "ephemeral");
discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages.asFlux(), staleMrmMessagesCounter, "mrm");
return messagesToPersist
.doFinally(signal -> {
sample.stop(getMessagesTimer);
staleEphemeralMessages.tryEmitComplete();
staleMrmMessages.tryEmitComplete();
}); });
} }
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
return getMessagesToPersistReactive(accountUuid, destinationDevice, limit)
.take(limit)
.collectList()
.block(Duration.ofSeconds(5));
}
public CompletableFuture<Void> clear(final UUID destinationUuid) { public CompletableFuture<Void> clear(final UUID destinationUuid) {
return CompletableFuture.allOf( return CompletableFuture.allOf(
Device.ALL_POSSIBLE_DEVICE_IDS.stream() Device.ALL_POSSIBLE_DEVICE_IDS.stream()

View File

@ -492,118 +492,64 @@ class MessagesCacheTest {
} }
@Test @Test
void testMessagesToPersistPagination() throws InterruptedException { void testEstimatePersistedQueueSize() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid); final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1; final byte deviceId = 1;
final byte[] queueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId);
final List<MessageProtos.Envelope> messages = IntStream.range(0, 60) // Should count all non-ephemeral, non-stale message bytes
.mapToObj(i -> switch (i % 3) { long expectedQueueSize = 0L;
// Stale MRM for (int i = 0; i < 400; i++) {
case 0 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) final MessageProtos.Envelope messageToInsert = switch (i % 4) {
// An MRM message
case 0 -> {
// First generate a random MRM message
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId);
final SealedSenderMultiRecipientMessage.Recipient recepient = mrm.getRecipients()
.get(serviceId.toLibsignal());
// Calculate the size of a message that has the shared content in it
final MessageProtos.Envelope message = generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
.setContent(ByteString.copyFrom(mrm.messageForRecipient(recepient)))
.build();
expectedQueueSize += message.getSerializedSize();
byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
// Insert the MRM message without the content
yield message
.toBuilder()
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
}
// A stale MRM message
case 1 ->
generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder() .toBuilder()
// clear some things added by the helper // clear some things added by the helper
.clearContent() .clearContent()
.setSharedMrmKey(MessagesCache.STALE_MRM_KEY) .setSharedMrmKey(MessagesCache.STALE_MRM_KEY)
.build(); .build();
// ephemeral message
case 1 -> generateRandomMessage(UUID.randomUUID(), serviceId, true) // An ephemeral message
.toBuilder() case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build();
.setEphemeral(true).build();
// standard message // A standardard message
case 2 -> generateRandomMessage(UUID.randomUUID(), serviceId, true); case 3 -> {
default -> throw new IllegalStateException(); final MessageProtos.Envelope message = generateRandomMessage(UUID.randomUUID(), serviceId, true);
}) expectedQueueSize += message.getSerializedSize();
.toList(); yield message;
for (MessageProtos.Envelope envelope : messages) { }
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join();
default -> throw new IllegalStateException();
};
messagesCache.insert(UUID.fromString(messageToInsert.getServerGuid()), destinationUuid, deviceId, messageToInsert).join();
} }
long actualQueueSize = messagesCache.estimatePersistedQueueSizeBytes(destinationUuid, deviceId).join();
final List<UUID> expectedGuidsToPersist = messages.stream() assertEquals(expectedQueueSize, actualQueueSize);
.filter(envelope -> !envelope.getEphemeral() && !envelope.hasSharedMrmKey())
.map(envelope -> UUID.fromString(envelope.getServerGuid()))
.limit(10)
.collect(Collectors.toList());
// Fetch 10 messages which should discard 20 ephemeral stale messages, and leave the rest
final List<UUID> actual = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 10).stream()
.map(envelope -> UUID.fromString(envelope.getServerGuid()))
.toList();
assertIterableEquals(expectedGuidsToPersist, actual);
// Eventually, the 20 ephemeral/stale messages should be discarded
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
while (REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(queueKey)) != 40) {
Thread.sleep(1);
}
}, "Ephemeral and stale messages should be deleted asynchronously");
// Let all pending tasks finish and make sure no more stale messages have been deleted
sharedExecutorService.shutdown();
sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS);
assertEquals(REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(queueKey)).longValue(), 40);
}
@Test
void testMessagesToPersistReactive() {
final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1;
final byte[] messageQueueKey = MessagesCache.getMessageQueueKey(destinationUuid, deviceId);
final List<MessageProtos.Envelope> messages = IntStream.range(0, 200)
.mapToObj(i -> {
if (i % 3 == 0) {
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId);
byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
return generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
// clear some things added by the helper
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
} else if (i % 13 == 0) {
return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build();
} else if (i % 17 == 0) {
return generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
// clear some things added by the helper
.clearContent()
.setSharedMrmKey(MessagesCache.STALE_MRM_KEY)
.build();
} else {
return generateRandomMessage(UUID.randomUUID(), serviceId, true);
}
})
.toList();
for (MessageProtos.Envelope envelope : messages) {
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join();
}
final List<MessageProtos.Envelope> expected = messages.stream()
.filter(envelope -> !envelope.getEphemeral() &&
(envelope.getSharedMrmKey() == null || !envelope.getSharedMrmKey().equals(MessagesCache.STALE_MRM_KEY)))
.toList();
final List<MessageProtos.Envelope> actual = messagesCache
.getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block();
assertEquals(expected.size(), actual.size());
for (int i = 0; i < actual.size(); i++) {
assertNotNull(actual.get(i).getContent());
assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid());
}
// Ephemeral messages and stale MRM messages are asynchronously deleted, but eventually they should all be removed
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
while (REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().zcard(messageQueueKey)) != expected.size()) {
Thread.sleep(1);
}
}, "Ephemeral and stale messages should be deleted asynchronously");
} }
@ParameterizedTest @ParameterizedTest