Skip shared multi-recipient message payloads for small messages

This commit is contained in:
Chris Eager 2025-03-21 14:26:27 -05:00 committed by Chris Eager
parent 9ef6024291
commit db2cd20dcb
2 changed files with 85 additions and 21 deletions

View File

@ -21,6 +21,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.function.BiFunction;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -41,7 +42,9 @@ import reactor.core.publisher.Mono;
public class MessagesManager { public class MessagesManager {
private static final int RESULT_SET_CHUNK_SIZE = 100; private static final int RESULT_SET_CHUNK_SIZE = 100;
final String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice"); private final static String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice");
// shared payloads have some overhead, which sometimes exceeds the size if we just wrote the content directly
private static final int MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD = 150;
private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class); private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class);
@ -139,17 +142,50 @@ public class MessagesManager {
final long serverTimestamp = clock.millis(); final long serverTimestamp = clock.millis();
return insertSharedMultiRecipientMessagePayload(multiRecipientMessage) final Envelope.Builder prototypeMessageBuilder = Envelope.newBuilder()
.thenCompose(sharedMrmKey -> { .setType(Envelope.Type.UNIDENTIFIED_SENDER)
final Envelope prototypeMessage = Envelope.newBuilder() .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp)
.setType(Envelope.Type.UNIDENTIFIED_SENDER) .setServerTimestamp(serverTimestamp)
.setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) .setStory(isStory)
.setServerTimestamp(serverTimestamp) .setEphemeral(isEphemeral)
.setStory(isStory) .setUrgent(isUrgent);
.setEphemeral(isEphemeral)
.setUrgent(isUrgent) final CompletableFuture<Envelope> prototypeMessageFuture;
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) final BiFunction<ServiceIdentifier, Envelope, Envelope> recipientEnvelopeBuilder;
// A shortcut -- message sizes do not vary by recipient in the current SealedSenderMultiRecipientMessage version
final int perRecipientMessageSize = multiRecipientMessage.getRecipients().values().stream().findAny()
.map(multiRecipientMessage::messageSizeForRecipient)
.orElse(0);
multiRecipientMessage.messageSizeForRecipient(
multiRecipientMessage.getRecipients().values().iterator().next());
if (perRecipientMessageSize >= MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD) {
// the message is large enough that the shared payload overhead is worth it, so insert into the cache
prototypeMessageFuture = insertSharedMultiRecipientMessagePayload((multiRecipientMessage))
.thenApply(sharedMrmKey -> prototypeMessageBuilder
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build());
recipientEnvelopeBuilder = (serviceIdentifier, prototype) -> prototype.toBuilder()
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.build();
} else {
prototypeMessageFuture = CompletableFuture.completedFuture(prototypeMessageBuilder.build());
recipientEnvelopeBuilder = (serviceIdentifier, prototype) ->
prototype.toBuilder()
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient(
multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal()))))
.build(); .build();
}
return prototypeMessageFuture
.thenCompose(prototypeMessage -> {
final Map<Account, Map<Byte, Boolean>> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>(); final Map<Account, Map<Byte, Boolean>> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>();
@ -162,9 +198,7 @@ public class MessagesManager {
return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI), return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI),
IntStream.range(0, devices.length).mapToObj(i -> devices[i]) IntStream.range(0, devices.length).mapToObj(i -> devices[i])
.collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder() .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> recipientEnvelopeBuilder.apply(serviceIdentifier, prototypeMessage))))
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.build())))
.thenAccept(clientPresenceByDeviceId -> .thenAccept(clientPresenceByDeviceId ->
clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient), clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient),
clientPresenceByDeviceId)); clientPresenceByDeviceId));

View File

@ -28,6 +28,9 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
@ -83,8 +86,15 @@ class MessagesManagerTest {
verifyNoMoreInteractions(reportMessageManager); verifyNoMoreInteractions(reportMessageManager);
} }
@Test @ParameterizedTest
void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException { @CsvSource({
"32, false",
"99, false",
"100, true",
"200, true",
"1024, true",
})
void insertMultiRecipientMessage(final int sharedPayloadSize, final boolean expectSharedMrm) throws InvalidMessageException, InvalidVersionException {
final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
@ -105,7 +115,7 @@ class MessagesManagerTest {
new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]), new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]),
new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]), new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]),
new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48]) new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48])
)); ), sharedPayloadSize);
final SealedSenderMultiRecipientMessage multiRecipientMessage = final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes); SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes);
@ -158,26 +168,46 @@ class MessagesManagerTest {
.setStory(isStory) .setStory(isStory)
.setEphemeral(isEphemeral) .setEphemeral(isEphemeral)
.setUrgent(isUrgent) .setUrgent(isUrgent)
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build(); .build();
final Map<ServiceIdentifier, Envelope> expectedEnvelopesByServiceIdentifier = Stream.of(singleDeviceAccountAciServiceIdentifier, singleDeviceAccountPniServiceIdentifier, multiDeviceAccountAciServiceIdentifier)
.collect(Collectors.toMap(
Function.identity(),
serviceIdentifier -> {
final Envelope.Builder envelopeBuilder = prototypeExpectedMessage.toBuilder()
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString());
if (expectSharedMrm) {
return envelopeBuilder
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build();
}
return envelopeBuilder.setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient(
multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal()))))
.build();
}
));
assertEquals(expectedPresenceByAccountAndDeviceId, assertEquals(expectedPresenceByAccountAndDeviceId,
messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join()); messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join());
verify(messagesCache).insert(any(), verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID), eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountAciServiceIdentifier)));
verify(messagesCache).insert(any(), verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID), eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build())); eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountPniServiceIdentifier)));
verify(messagesCache).insert(any(), verify(messagesCache).insert(any(),
eq(multiDeviceAccountAciServiceIdentifier.uuid()), eq(multiDeviceAccountAciServiceIdentifier.uuid()),
eq((byte) (Device.PRIMARY_ID + 1)), eq((byte) (Device.PRIMARY_ID + 1)),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); eq(expectedEnvelopesByServiceIdentifier.get(multiDeviceAccountAciServiceIdentifier)));
verify(messagesCache, never()).insert(any(), verify(messagesCache, never()).insert(any(),
eq(unresolvedAccountAciServiceIdentifier.uuid()), eq(unresolvedAccountAciServiceIdentifier.uuid()),