From db2cd20dcb400360895f924bf99f13731cbe776f Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 21 Mar 2025 14:26:27 -0500 Subject: [PATCH] Skip shared multi-recipient message payloads for small messages --- .../storage/MessagesManager.java | 62 ++++++++++++++----- .../storage/MessagesManagerTest.java | 44 ++++++++++--- 2 files changed, 85 insertions(+), 21 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 92a3426d7..09a3dad11 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -21,6 +21,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; @@ -41,7 +42,9 @@ import reactor.core.publisher.Mono; public class MessagesManager { private static final int RESULT_SET_CHUNK_SIZE = 100; - final String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice"); + private final static String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice"); + // shared payloads have some overhead, which sometimes exceeds the size if we just wrote the content directly + private static final int MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD = 150; private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class); @@ -139,17 +142,50 @@ public class MessagesManager { final long serverTimestamp = clock.millis(); - return insertSharedMultiRecipientMessagePayload(multiRecipientMessage) - .thenCompose(sharedMrmKey -> { - final Envelope prototypeMessage = Envelope.newBuilder() - .setType(Envelope.Type.UNIDENTIFIED_SENDER) - .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) - .setServerTimestamp(serverTimestamp) - .setStory(isStory) - .setEphemeral(isEphemeral) - .setUrgent(isUrgent) - .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) + final Envelope.Builder prototypeMessageBuilder = Envelope.newBuilder() + .setType(Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) + .setServerTimestamp(serverTimestamp) + .setStory(isStory) + .setEphemeral(isEphemeral) + .setUrgent(isUrgent); + + final CompletableFuture prototypeMessageFuture; + final BiFunction recipientEnvelopeBuilder; + + // A shortcut -- message sizes do not vary by recipient in the current SealedSenderMultiRecipientMessage version + final int perRecipientMessageSize = multiRecipientMessage.getRecipients().values().stream().findAny() + .map(multiRecipientMessage::messageSizeForRecipient) + .orElse(0); + + multiRecipientMessage.messageSizeForRecipient( + multiRecipientMessage.getRecipients().values().iterator().next()); + if (perRecipientMessageSize >= MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD) { + + // the message is large enough that the shared payload overhead is worth it, so insert into the cache + prototypeMessageFuture = insertSharedMultiRecipientMessagePayload((multiRecipientMessage)) + .thenApply(sharedMrmKey -> prototypeMessageBuilder + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) + .build()); + + recipientEnvelopeBuilder = (serviceIdentifier, prototype) -> prototype.toBuilder() + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .build(); + + } else { + + prototypeMessageFuture = CompletableFuture.completedFuture(prototypeMessageBuilder.build()); + + recipientEnvelopeBuilder = (serviceIdentifier, prototype) -> + prototype.toBuilder() + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient( + multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal())))) .build(); + } + + return prototypeMessageFuture + .thenCompose(prototypeMessage -> { final Map> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>(); @@ -162,9 +198,7 @@ public class MessagesManager { return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI), IntStream.range(0, devices.length).mapToObj(i -> devices[i]) - .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder() - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) - .build()))) + .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> recipientEnvelopeBuilder.apply(serviceIdentifier, prototypeMessage)))) .thenAccept(clientPresenceByDeviceId -> clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient), clientPresenceByDeviceId)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index c7691c0c8..eebc6ec4c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -28,6 +28,9 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -83,8 +86,15 @@ class MessagesManagerTest { verifyNoMoreInteractions(reportMessageManager); } - @Test - void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException { + @ParameterizedTest + @CsvSource({ + "32, false", + "99, false", + "100, true", + "200, true", + "1024, true", + }) + void insertMultiRecipientMessage(final int sharedPayloadSize, final boolean expectSharedMrm) throws InvalidMessageException, InvalidVersionException { final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); @@ -105,7 +115,7 @@ class MessagesManagerTest { new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]), new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]), new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48]) - )); + ), sharedPayloadSize); final SealedSenderMultiRecipientMessage multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes); @@ -158,26 +168,46 @@ class MessagesManagerTest { .setStory(isStory) .setEphemeral(isEphemeral) .setUrgent(isUrgent) - .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) .build(); + final Map expectedEnvelopesByServiceIdentifier = Stream.of(singleDeviceAccountAciServiceIdentifier, singleDeviceAccountPniServiceIdentifier, multiDeviceAccountAciServiceIdentifier) + .collect(Collectors.toMap( + Function.identity(), + serviceIdentifier -> { + + final Envelope.Builder envelopeBuilder = prototypeExpectedMessage.toBuilder() + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()); + + if (expectSharedMrm) { + return envelopeBuilder + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) + .build(); + } + + return envelopeBuilder.setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient( + multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal())))) + .build(); + + } + )); + assertEquals(expectedPresenceByAccountAndDeviceId, messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join()); verify(messagesCache).insert(any(), eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(Device.PRIMARY_ID), - eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); + eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountAciServiceIdentifier))); verify(messagesCache).insert(any(), eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(Device.PRIMARY_ID), - eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build())); + eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountPniServiceIdentifier))); verify(messagesCache).insert(any(), eq(multiDeviceAccountAciServiceIdentifier.uuid()), eq((byte) (Device.PRIMARY_ID + 1)), - eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); + eq(expectedEnvelopesByServiceIdentifier.get(multiDeviceAccountAciServiceIdentifier))); verify(messagesCache, never()).insert(any(), eq(unresolvedAccountAciServiceIdentifier.uuid()),