diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 09a3dad11..92a3426d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -21,7 +21,6 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; @@ -42,9 +41,7 @@ import reactor.core.publisher.Mono; public class MessagesManager { private static final int RESULT_SET_CHUNK_SIZE = 100; - private final static String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice"); - // shared payloads have some overhead, which sometimes exceeds the size if we just wrote the content directly - private static final int MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD = 150; + final String GET_MESSAGES_FOR_DEVICE_FLUX_NAME = name(MessagesManager.class, "getMessagesForDevice"); private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class); @@ -142,50 +139,17 @@ public class MessagesManager { final long serverTimestamp = clock.millis(); - final Envelope.Builder prototypeMessageBuilder = Envelope.newBuilder() - .setType(Envelope.Type.UNIDENTIFIED_SENDER) - .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) - .setServerTimestamp(serverTimestamp) - .setStory(isStory) - .setEphemeral(isEphemeral) - .setUrgent(isUrgent); - - final CompletableFuture prototypeMessageFuture; - final BiFunction recipientEnvelopeBuilder; - - // A shortcut -- message sizes do not vary by recipient in the current SealedSenderMultiRecipientMessage version - final int perRecipientMessageSize = multiRecipientMessage.getRecipients().values().stream().findAny() - .map(multiRecipientMessage::messageSizeForRecipient) - .orElse(0); - - multiRecipientMessage.messageSizeForRecipient( - multiRecipientMessage.getRecipients().values().iterator().next()); - if (perRecipientMessageSize >= MULTI_RECIPIENT_MESSAGE_MINIMUM_SIZE_FOR_SHARED_PAYLOAD) { - - // the message is large enough that the shared payload overhead is worth it, so insert into the cache - prototypeMessageFuture = insertSharedMultiRecipientMessagePayload((multiRecipientMessage)) - .thenApply(sharedMrmKey -> prototypeMessageBuilder - .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) - .build()); - - recipientEnvelopeBuilder = (serviceIdentifier, prototype) -> prototype.toBuilder() - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) - .build(); - - } else { - - prototypeMessageFuture = CompletableFuture.completedFuture(prototypeMessageBuilder.build()); - - recipientEnvelopeBuilder = (serviceIdentifier, prototype) -> - prototype.toBuilder() - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) - .setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient( - multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal())))) + return insertSharedMultiRecipientMessagePayload(multiRecipientMessage) + .thenCompose(sharedMrmKey -> { + final Envelope prototypeMessage = Envelope.newBuilder() + .setType(Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) + .setServerTimestamp(serverTimestamp) + .setStory(isStory) + .setEphemeral(isEphemeral) + .setUrgent(isUrgent) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) .build(); - } - - return prototypeMessageFuture - .thenCompose(prototypeMessage -> { final Map> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>(); @@ -198,7 +162,9 @@ public class MessagesManager { return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI), IntStream.range(0, devices.length).mapToObj(i -> devices[i]) - .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> recipientEnvelopeBuilder.apply(serviceIdentifier, prototypeMessage)))) + .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder() + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .build()))) .thenAccept(clientPresenceByDeviceId -> clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient), clientPresenceByDeviceId)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index eebc6ec4c..c7691c0c8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -28,9 +28,6 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.concurrent.ThreadLocalRandom; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -86,15 +83,8 @@ class MessagesManagerTest { verifyNoMoreInteractions(reportMessageManager); } - @ParameterizedTest - @CsvSource({ - "32, false", - "99, false", - "100, true", - "200, true", - "1024, true", - }) - void insertMultiRecipientMessage(final int sharedPayloadSize, final boolean expectSharedMrm) throws InvalidMessageException, InvalidVersionException { + @Test + void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException { final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); @@ -115,7 +105,7 @@ class MessagesManagerTest { new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]), new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]), new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48]) - ), sharedPayloadSize); + )); final SealedSenderMultiRecipientMessage multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes); @@ -168,46 +158,26 @@ class MessagesManagerTest { .setStory(isStory) .setEphemeral(isEphemeral) .setUrgent(isUrgent) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) .build(); - final Map expectedEnvelopesByServiceIdentifier = Stream.of(singleDeviceAccountAciServiceIdentifier, singleDeviceAccountPniServiceIdentifier, multiDeviceAccountAciServiceIdentifier) - .collect(Collectors.toMap( - Function.identity(), - serviceIdentifier -> { - - final Envelope.Builder envelopeBuilder = prototypeExpectedMessage.toBuilder() - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()); - - if (expectSharedMrm) { - return envelopeBuilder - .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) - .build(); - } - - return envelopeBuilder.setContent(ByteString.copyFrom(multiRecipientMessage.messageForRecipient( - multiRecipientMessage.getRecipients().get(serviceIdentifier.toLibsignal())))) - .build(); - - } - )); - assertEquals(expectedPresenceByAccountAndDeviceId, messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join()); verify(messagesCache).insert(any(), eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(Device.PRIMARY_ID), - eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountAciServiceIdentifier))); + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); verify(messagesCache).insert(any(), eq(singleDeviceAccountAciServiceIdentifier.uuid()), eq(Device.PRIMARY_ID), - eq(expectedEnvelopesByServiceIdentifier.get(singleDeviceAccountPniServiceIdentifier))); + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build())); verify(messagesCache).insert(any(), eq(multiDeviceAccountAciServiceIdentifier.uuid()), eq((byte) (Device.PRIMARY_ID + 1)), - eq(expectedEnvelopesByServiceIdentifier.get(multiDeviceAccountAciServiceIdentifier))); + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); verify(messagesCache, never()).insert(any(), eq(unresolvedAccountAciServiceIdentifier.uuid()),