From 3f9863c441e67029c69dfa8c1451911de5032010 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Wed, 29 Jan 2025 11:50:20 -0600 Subject: [PATCH] Discard mrm messages that can never be sent --- .../textsecuregcm/storage/MessagesCache.java | 87 +++++++++++++++---- .../storage/MrmDataMissingException.java | 21 +++++ .../storage/MessagesCacheTest.java | 25 ++++-- 3 files changed, 109 insertions(+), 24 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/MrmDataMissingException.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index fb78e0cd6..3633986db 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -14,6 +14,7 @@ import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Timer; import java.io.IOException; import java.nio.ByteBuffer; @@ -131,9 +132,11 @@ public class MessagesCache { private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); private final Counter staleEphemeralMessagesCounter = Metrics.counter( name(MessagesCache.class, "staleEphemeralMessages")); + private final Counter staleMrmMessagesCounter = Metrics.counter(name(MessagesCache.class, "staleMrmMessages")); private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved")); private final String MRM_RETRIEVAL_ERROR_COUNTER_NAME = name(MessagesCache.class, "mrmRetrievalError"); private final String EPHEMERAL_TAG_NAME = "ephemeral"; + private final String MISSING_MRM_DATA_TAG_NAME = "missingMrmData"; private final Counter skippedStaleEphemeralMrmCounter = Metrics.counter( name(MessagesCache.class, "skippedStaleEphemeralMrm")); private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter( @@ -142,6 +145,8 @@ public class MessagesCache { static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); + private static final ByteString STALE_MRM_KEY = ByteString.copyFromUtf8("stale"); + @VisibleForTesting static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); @@ -287,20 +292,26 @@ public class MessagesCache { final Flux allMessages = getAllMessages(destinationUuid, destinationDevice, earliestAllowableEphemeralTimestamp, PAGE_SIZE) .publish() - // We expect exactly two subscribers to this base flux: + // We expect exactly three subscribers to this base flux: // 1. the websocket that delivers messages to clients - // 2. an internal process to discard stale ephemeral messages - // The discard subscriber will subscribe immediately, but we don’t want to do any work if the + // 2. an internal processes to discard stale ephemeral messages + // 3. an internal process to discard stale MRM messages + // The discard subscribers will subscribe immediately, but we don’t want to do any work if the // websocket never subscribes. - .autoConnect(2); + .autoConnect(3); final Flux messagesToPublish = allMessages - .filter(Predicate.not(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp))); + .filter(Predicate.not(envelope -> + isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp) || isStaleMrmMessage(envelope))); final Flux staleEphemeralMessages = allMessages .filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp)); + discardStaleMessages(destinationUuid, destinationDevice, staleEphemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); - discardStaleEphemeralMessages(destinationUuid, destinationDevice, staleEphemeralMessages); + final Flux staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage) + // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data + .map(envelope -> envelope.toBuilder().clearSharedMrmKey().build()); + discardStaleMessages(destinationUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); return messagesToPublish.name(GET_FLUX_NAME) .tap(Micrometer.metrics(Metrics.globalRegistry)); @@ -317,16 +328,25 @@ public class MessagesCache { return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp; } - private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice, - Flux staleEphemeralMessages) { - staleEphemeralMessages + /** + * Checks whether the given message is a stale MRM message + * + * @see #getMessageWithSharedMrmData(MessageProtos.Envelope, byte) + */ + private static boolean isStaleMrmMessage(final MessageProtos.Envelope message) { + return message.hasSharedMrmKey() && STALE_MRM_KEY.equals(message.getSharedMrmKey()); + } + + private void discardStaleMessages(final UUID destinationUuid, final byte destinationDevice, + Flux staleMessages, final Counter counter, final String context) { + staleMessages .map(e -> UUID.fromString(e.getServerGuid())) .buffer(PAGE_SIZE) .subscribeOn(messageDeletionScheduler) - .subscribe(staleEphemeralMessageGuids -> - remove(destinationUuid, destinationDevice, staleEphemeralMessageGuids) - .thenAccept(removedMessages -> staleEphemeralMessagesCounter.increment(removedMessages.size())), - e -> logger.warn("Could not remove stale ephemeral messages from cache", e)); + .subscribe(messageGuids -> + remove(destinationUuid, destinationDevice, messageGuids) + .thenAccept(removedMessages -> counter.increment(removedMessages.size())), + e -> logger.warn("Could not remove stale {} messages from cache", context, e)); } @VisibleForTesting @@ -382,7 +402,13 @@ public class MessagesCache { } /** - * Returns the given message with its shared MRM data. + * Returns the given message with its shared MRM data. There are three possible cases: + *
    + *
  1. The reconstructed message for delivery with {@code content} set and {@code sharedMrmKey} cleared
  2. + *
  3. The input with {@code sharedMrmKey} set to a static value, indicating that the shared MRM data is no longer available, and the message should be + * discarded from the queue
  4. + *
  5. An empty {@code Mono}, if an unexpected error occurred
  6. + *
*/ private Mono getMessageWithSharedMrmData(final MessageProtos.Envelope mrmMessage, final byte destinationDevice) { @@ -402,6 +428,18 @@ public class MessagesCache { try { assert mrmDataAndView.size() == 2; + if (mrmDataAndView.getFirst().isEmpty()) { + // shared data is missing + //noinspection ReactiveStreamsThrowInOperator + throw new MrmDataMissingException(MrmDataMissingException.Type.SHARED); + } + + if (mrmDataAndView.getLast().isEmpty()) { + // recipient's view is missing + //noinspection ReactiveStreamsThrowInOperator + throw new MrmDataMissingException(MrmDataMissingException.Type.RECIPIENT_VIEW); + } + final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient( mrmDataAndView.getFirst().getValue(), mrmDataAndView.getLast().getValue()); @@ -418,11 +456,24 @@ public class MessagesCache { }) .onErrorResume(throwable -> { logger.warn("Failed to retrieve shared mrm data", throwable); - Metrics.counter(MRM_RETRIEVAL_ERROR_COUNTER_NAME, - EPHEMERAL_TAG_NAME, String.valueOf(mrmMessage.getEphemeral())) - .increment(); - return Mono.empty(); + final List tags = new ArrayList<>(); + tags.add(Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(mrmMessage.getEphemeral()))); + + final Mono result; + if (throwable instanceof MrmDataMissingException mdme) { + tags.add(Tag.of(MISSING_MRM_DATA_TAG_NAME, mdme.getType().name())); + // MRM data may be missing if either of the two non-transactional writes (delete from queue, update shared + // MRM data) fails after it has been delivered. We return it so that it may be discarded from the queue. + result = Mono.just(mrmMessage.toBuilder().setSharedMrmKey(STALE_MRM_KEY).build()); + } else { + // For unexpected errors, return empty. The message will remain in the queue and be retried in the future. + result = Mono.empty(); + } + + Metrics.counter(MRM_RETRIEVAL_ERROR_COUNTER_NAME, tags).increment(); + + return result; }) .share(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MrmDataMissingException.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MrmDataMissingException.java new file mode 100644 index 000000000..8557fc05a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MrmDataMissingException.java @@ -0,0 +1,21 @@ +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.util.NoStackTraceRuntimeException; + +class MrmDataMissingException extends NoStackTraceRuntimeException { + + enum Type { + SHARED, + RECIPIENT_VIEW + } + + private final Type type; + + MrmDataMissingException(final Type type) { + this.type = type; + } + + Type getType() { + return type; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index a3332e738..06f7f36ec 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -446,23 +446,36 @@ class MessagesCacheTest { .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); final List messages = get(destinationServiceId.uuid(), deviceId, 1); + if (!sharedMrmKeyPresent) { assertTrue(messages.isEmpty()); - } else { + // the discard is purely async, so we just wait for it + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { + boolean exists; + do { + exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withBinaryCluster(conn -> + conn.sync().hlen(MessagesCache.getMessageQueueKey(destinationServiceId.uuid(), deviceId))); + } while (exists); + }, "Stale MRM message should be deleted asynchronously"); + + } else { assertEquals(1, messages.size()); + assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid())); assertFalse(messages.getFirst().hasSharedMrmKey()); final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients() .get(destinationServiceId.toLibsignal()); assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray()); + + final Optional removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid) + .join(); + + assertTrue(removedMessage.isPresent()); + assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString())); } - final Optional removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid) - .join(); - - assertTrue(removedMessage.isPresent()); - assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString())); assertTrue(get(destinationServiceId.uuid(), deviceId, 1).isEmpty()); // updating the shared MRM data is purely async, so we just wait for it