From c03d63acb8024872dd8f7e8d534d239e1a204795 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 7 Apr 2025 10:11:43 -0400 Subject: [PATCH] Centralize message size validation in actual message-sending methods --- .../controllers/MessageController.java | 34 +--- .../textsecuregcm/push/MessageSender.java | 55 +++++- .../textsecuregcm/push/ReceiptSender.java | 2 +- .../storage/ChangeNumberManager.java | 11 +- .../controllers/MessageControllerTest.java | 169 +++++++++++++----- .../textsecuregcm/push/MessageSenderTest.java | 12 +- .../storage/ChangeNumberManagerTest.java | 17 +- 7 files changed, 199 insertions(+), 101 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 231a570cf..befef1968 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -322,21 +322,10 @@ public class MessageController { final Optional reportSpamToken = spamCheckResult.token(); - int totalContentLength = 0; - - for (final IncomingMessage message : messages.messages()) { - final int contentLength = message.content().length; - - try { - MessageSender.validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent); - } catch (final MessageTooLargeException e) { - throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); - } - - totalContentLength += contentLength; - } - try { + final int totalContentLength = + messages.messages().stream().mapToInt(message -> message.content().length).sum(); + rateLimiters.getInboundMessageBytes().validate(destinationIdentifier.uuid(), totalContentLength); } catch (final RateLimitExceededException e) { messageByteLimitEstimator.add(destinationIdentifier.uuid().toString()); @@ -409,7 +398,7 @@ public class MessageController { authType = AUTH_TYPE_ACCESS_KEY; } - messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId); + messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId, userAgent); Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE), @@ -433,6 +422,8 @@ public class MessageController { e.getMismatchedDevices().extraDeviceIds())) .build()); } + } catch (final MessageTooLargeException e) { + throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); } } finally { sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME) @@ -505,15 +496,6 @@ public class MessageController { throw new BadRequestException("Recipient list is empty"); } - // Verify that the message isn't too large before performing more expensive validations - multiRecipientMessage.getRecipients().values().forEach(recipient -> { - try { - MessageSender.validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, false, isStory, userAgent); - } catch (final MessageTooLargeException e) { - throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); - } - }); - // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same // recipient { @@ -616,7 +598,7 @@ public class MessageController { try { if (!resolvedRecipients.isEmpty()) { - messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, timestamp, isStory, online, isUrgent).get(); + messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, timestamp, isStory, online, isUrgent, userAgent).get(); } final List unresolvedRecipientServiceIds; @@ -695,6 +677,8 @@ public class MessageController { } throw new RuntimeException(e); + } catch (final MessageTooLargeException e) { + throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index 402caa412..8649995e0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Util; +import javax.annotation.Nullable; /** * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, @@ -86,11 +87,18 @@ public class MessageSender { * @param destinationIdentifier the service identifier to which the messages are addressed * @param messagesByDeviceId a map of device IDs to message payloads * @param registrationIdsByDeviceId a map of device IDs to device registration IDs + * @param userAgent the User-Agent string for the sender; may be {@code null} if not known + * + * @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required + * devices, contained messages for devices not linked to the destination account, or devices with outdated + * registration IDs + * @throws MessageTooLargeException if the given message payload is too large */ public void sendMessages(final Account destination, final ServiceIdentifier destinationIdentifier, final Map messagesByDeviceId, - final Map registrationIdsByDeviceId) throws MismatchedDevicesException { + final Map registrationIdsByDeviceId, + @Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException { if (messagesByDeviceId.isEmpty()) { return; @@ -105,6 +113,10 @@ public class MessageSender { final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) && destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId())); + final boolean isStory = firstMessage.getStory(); + + validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent); + final Optional maybeMismatchedDevices = getMismatchedDevices(destination, destinationIdentifier, registrationIdsByDeviceId, @@ -152,15 +164,24 @@ public class MessageSender { * @param isEphemeral {@code true} if the message should only be delivered to devices with active connections or * {@code false otherwise} * @param isUrgent {@code true} if the message is urgent or {@code false otherwise} + * @param userAgent the User-Agent string for the sender; may be {@code null} if not known * * @return a future that completes when all messages have been inserted into delivery queues + * + * @throws MultiRecipientMismatchedDevicesException if the given multi-recipient message had did not have all required + * recipient devices for a recipient account, contained recipients for devices not linked to a destination account, or + * recipient devices with outdated registration IDs + * @throws MessageTooLargeException if the given message payload is too large */ public CompletableFuture sendMultiRecipientMessage(final SealedSenderMultiRecipientMessage multiRecipientMessage, final Map resolvedRecipients, final long clientTimestamp, final boolean isStory, final boolean isEphemeral, - final boolean isUrgent) throws MultiRecipientMismatchedDevicesException { + final boolean isUrgent, + @Nullable final String userAgent) throws MultiRecipientMismatchedDevicesException, MessageTooLargeException { + + validateMultiRecipientMessageContentLength(multiRecipientMessage, isStory, userAgent); final Map mismatchedDevicesByServiceIdentifier = new HashMap<>(); @@ -224,7 +245,8 @@ public class MessageSender { } } - public static void validateContentLength(final int contentLength, + @VisibleForTesting + static void validateContentLength(final int contentLength, final boolean isMultiRecipientMessage, final boolean isSyncMessage, final boolean isStory, @@ -302,4 +324,31 @@ public class MessageSender { ? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds)) : Optional.empty(); } + + private static void validateIndividualMessageContentLength(final Iterable messages, + final boolean isSyncMessage, + final boolean isStory, + @Nullable final String userAgent) throws MessageTooLargeException { + + for (final Envelope message : messages) { + MessageSender.validateContentLength(message.getContent().size(), + false, + isSyncMessage, + isStory, + userAgent); + } + } + + private static void validateMultiRecipientMessageContentLength(final SealedSenderMultiRecipientMessage multiRecipientMessage, + final boolean isStory, + @Nullable final String userAgent) throws MessageTooLargeException { + + for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { + MessageSender.validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), + true, + false, + isStory, + userAgent); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index 5d1b270a8..3ef17ca8c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -68,7 +68,7 @@ public class ReceiptSender { messageSender.sendMessages(destinationAccount, destinationIdentifier, messagesByDeviceId, - registrationIdsByDeviceId); + registrationIdsByDeviceId, null); } catch (final Exception e) { logger.warn("Could not send delivery receipt", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index d0d0cffce..a77227079 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -19,7 +19,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageTooLargeException; @@ -96,14 +95,6 @@ public class ChangeNumberManager { final List deviceMessages, final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { - for (final IncomingMessage message : deviceMessages) { - MessageSender.validateContentLength(message.content().length, - false, - true, - false, - senderUserAgent); - } - try { final long serverTimestamp = System.currentTimeMillis(); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); @@ -125,7 +116,7 @@ public class ChangeNumberManager { final Map registrationIdsByDeviceId = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); - messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId); + messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent); } catch (final RuntimeException e) { logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); throw e; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index ce2524cb2..610c31eb8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -40,6 +40,7 @@ import jakarta.ws.rs.client.Invocation; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; +import java.io.IOException; import java.time.Duration; import java.time.Instant; import java.time.LocalDate; @@ -56,6 +57,7 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.hamcrest.Matcher; @@ -73,7 +75,6 @@ import org.mockito.ArgumentCaptor; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; -import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -95,6 +96,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -121,7 +123,6 @@ import org.whispersystems.websocket.WebsocketHeaders; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -import javax.annotation.Nullable; @ExtendWith(DropwizardExtensionsSupport.class) class MessageControllerTest { @@ -197,10 +198,10 @@ class MessageControllerTest { .build(); @BeforeEach - void setup() throws MultiRecipientMismatchedDevicesException { + void setup() throws MultiRecipientMismatchedDevicesException, MessageTooLargeException { reset(pushNotificationScheduler); - when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) + when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final List singleDeviceList = List.of( @@ -289,7 +290,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -334,7 +335,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -359,7 +360,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -397,7 +398,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -436,7 +437,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); if (expectedResponse == 200) { @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -533,7 +534,7 @@ class MessageControllerTest { @Test void testMultiDeviceMissing() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet()))) - .when(messageSender).sendMessages(any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() @@ -555,7 +556,7 @@ class MessageControllerTest { @Test void testMultiDeviceExtra() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet()))) - .when(messageSender).sendMessages(any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() @@ -606,7 +607,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -630,7 +631,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -654,6 +655,7 @@ class MessageControllerTest { verify(messageSender).sendMessages(any(Account.class), any(), argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3), + any(), any()); } } @@ -661,7 +663,7 @@ class MessageControllerTest { @Test void testRegistrationIdMismatch() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2)))) - .when(messageSender).sendMessages(any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) @@ -1085,24 +1087,19 @@ class MessageControllerTest { } @Test - void testValidateContentLength() throws MismatchedDevicesException { - final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1); - final byte[] contentBytes = new byte[contentLength]; - Arrays.fill(contentBytes, (byte) 1); + void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException { + doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .request() - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) - .put(Entity.entity(new IncomingMessageList( - List.of(new IncomingMessage(1, (byte) 1, 1, contentBytes)), false, true, - System.currentTimeMillis()), + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), + IncomingMessageList.class), MediaType.APPLICATION_JSON_TYPE))) { - assertThat("Bad response", response.getStatus(), is(equalTo(413))); - - verify(messageSender, never()).sendMessages(any(), any(), any(), any()); + assertThat(response.getStatus(), is(equalTo(413))); } } @@ -1120,10 +1117,10 @@ class MessageControllerTest { if (expectOk) { assertEquals(200, response.getStatus()); - verify(messageSender).sendMessages(any(), any(), any(), any()); + verify(messageSender).sendMessages(any(), any(), any(), any(), any()); } else { assertEquals(422, response.getStatus()); - verify(messageSender, never()).sendMessages(any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); } } } @@ -1149,7 +1146,7 @@ class MessageControllerTest { final Set expectedResolvedAccounts, final Set expectedUuids404, @Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException) - throws MultiRecipientMismatchedDevicesException { + throws MultiRecipientMismatchedDevicesException, MessageTooLargeException { clock.pin(START_OF_DAY); @@ -1162,7 +1159,7 @@ class MessageControllerTest { if (mismatchedDevicesException != null) { doThrow(mismatchedDevicesException) - .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()); + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); } final boolean ephemeral = true; @@ -1208,9 +1205,11 @@ class MessageControllerTest { anyLong(), eq(isStory), eq(ephemeral), - eq(urgent)); + eq(urgent), + any()); } else { - verify(messageSender, never()).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()); + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); } } } @@ -1392,23 +1391,6 @@ class MessageControllerTest { Set.of(), null), - Arguments.argumentSet("Oversized payload", - accountsByServiceIdentifier, - MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( - new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), - new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), - new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48])), - MultiRecipientMessageProvider.MAX_MESSAGE_SIZE), - clock.instant().toEpochMilli(), - false, - false, - Optional.empty(), - Optional.of(groupSendEndorsement), - 413, - Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of(), - null), - Arguments.argumentSet("Negative timestamp", accountsByServiceIdentifier, aciMessage, @@ -1646,6 +1628,97 @@ class MessageControllerTest { ); } + @Test + void sendMultiRecipientMessageOversized() throws Exception { + + clock.pin(START_OF_DAY); + + final UUID singleDeviceAccountAci = UUID.randomUUID(); + final UUID singleDeviceAccountPni = UUID.randomUUID(); + final UUID multiDeviceAccountAci = UUID.randomUUID(); + final UUID multiDeviceAccountPni = UUID.randomUUID(); + + final byte[] singleDeviceAccountUak = TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH); + final byte[] multiDeviceAccountUak = TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH); + + final int singleDevicePrimaryRegistrationId = 1; + final int multiDevicePrimaryRegistrationId = 2; + final int multiDeviceLinkedRegistrationId = 3; + + final Device singleDeviceAccountPrimary = mock(Device.class); + when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID); + when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId); + + final Device multiDeviceAccountPrimary = mock(Device.class); + when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID); + when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId); + + final Device multiDeviceAccountLinked = mock(Device.class); + when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1)); + when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId); + + final Account singleDeviceAccount = mock(Account.class); + when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci); + when(singleDeviceAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(singleDeviceAccountUak)); + when(singleDeviceAccount.getDevices()).thenReturn(List.of(singleDeviceAccountPrimary)); + when(singleDeviceAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + when(singleDeviceAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(singleDeviceAccountPrimary)); + + final Account multiDeviceAccount = mock(Account.class); + when(multiDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(multiDeviceAccountAci); + when(multiDeviceAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(multiDeviceAccountUak)); + when(multiDeviceAccount.getDevices()).thenReturn(List.of(multiDeviceAccountPrimary, multiDeviceAccountLinked)); + when(multiDeviceAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + when(multiDeviceAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(multiDeviceAccountPrimary)); + when(multiDeviceAccount.getDevice((byte) (Device.PRIMARY_ID + 1))).thenReturn(Optional.of(multiDeviceAccountLinked)); + + final Map accountsByServiceIdentifier = Map.of( + new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount, + new AciServiceIdentifier(multiDeviceAccountAci), multiDeviceAccount, + new PniServiceIdentifier(singleDeviceAccountPni), singleDeviceAccount, + new PniServiceIdentifier(multiDeviceAccountPni), multiDeviceAccount); + + final byte[] aciMessage = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]))); + + when(accountsManager.getByServiceIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + accountsByServiceIdentifier.forEach(((serviceIdentifier, account) -> + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))))); + + final boolean ephemeral = true; + final boolean urgent = false; + final boolean story = false; + + final Invocation.Builder invocationBuilder = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("ts", clock.millis()) + .queryParam("online", ephemeral) + .queryParam("story", story) + .queryParam("urgent", urgent) + .request() + .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(serverSecretParams, + List.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)), + START_OF_DAY.plus(Duration.ofDays(1)))); + + when(rateLimiter.validateAsync(any(UUID.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + try (final Response response = invocationBuilder + .put(Entity.entity(aciMessage, MultiRecipientMessageProvider.MEDIA_TYPE))) { + + assertThat(response.getStatus(), is(equalTo(413))); + } + } + @SuppressWarnings("SameParameterValue") private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 55b698723..f5df1cc23 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -104,7 +104,8 @@ class MessageSenderTest { assertDoesNotThrow(() -> messageSender.sendMessages(account, serviceIdentifier, Map.of(device.getId(), message), - Map.of(device.getId(), registrationId))); + Map.of(device.getId(), registrationId), + null)); final MessageProtos.Envelope expectedMessage = ephemeral ? message.toBuilder().setEphemeral(true).build() @@ -143,7 +144,8 @@ class MessageSenderTest { assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account, serviceIdentifier, Map.of(device.getId(), message), - Map.of(device.getId(), registrationId + 1))); + Map.of(device.getId(), registrationId + 1), + null)); assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)), mismatchedDevicesException.getMismatchedDevices()); @@ -197,7 +199,8 @@ class MessageSenderTest { System.currentTimeMillis(), false, ephemeral, - urgent) + urgent, + null) .join()); if (expectPushNotificationAttempt) { @@ -243,7 +246,8 @@ class MessageSenderTest { System.currentTimeMillis(), false, false, - true) + true, + null) .join()); assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 2d82edb0e..85004c317 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -31,13 +31,10 @@ import org.mockito.stubbing.Answer; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -106,7 +103,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); - verify(messageSender, never()).sendMessages(eq(account), any(), any(), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); } @Test @@ -120,7 +117,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); - verify(messageSender, never()).sendMessages(eq(account), any(), any(), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); } @Test @@ -160,7 +157,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -213,7 +210,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -264,7 +261,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -311,7 +308,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -360,7 +357,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());