From 50f681ffe8518bd2623ee5fcbdfa12ac39efcb9d Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Tue, 4 Mar 2025 11:08:18 -0500 Subject: [PATCH] Centralize message length validation --- .../controllers/MessageController.java | 60 ++++--------------- .../textsecuregcm/push/MessageSender.java | 52 ++++++++++++++++ .../push/MessageTooLargeException.java | 11 ++++ .../util/NoStackTraceException.java | 29 +++++++++ .../controllers/MessageControllerTest.java | 17 ++---- .../textsecuregcm/push/MessageSenderTest.java | 11 ++++ 6 files changed, 121 insertions(+), 59 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/push/MessageTooLargeException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceException.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 85ad5c493..9969bf2e9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -7,11 +7,8 @@ package org.whispersystems.textsecuregcm.controllers; import static com.codahale.metrics.MetricRegistry.name; import com.codahale.metrics.annotation.Timed; -import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; -import io.dropwizard.util.DataSize; -import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; @@ -102,6 +99,7 @@ import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -155,10 +153,7 @@ public class MessageController { private static final CompletableFuture[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; - private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); - private static final String LARGE_BUT_NOT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "largeMessage"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); - private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage"); @@ -184,10 +179,6 @@ public class MessageController { private static final String ENDPOINT_TYPE_SINGLE = "single"; private static final String ENDPOINT_TYPE_MULTI = "multi"; - @VisibleForTesting - static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); - private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes(); - // The Signal desktop client (really, JavaScript in general) can handle message timestamps at most 100,000,000 days // past the epoch; please see https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Date#the_epoch_timestamps_and_invalid_date // for additional details. @@ -325,7 +316,11 @@ public class MessageController { for (final IncomingMessage message : messages.messages()) { final int contentLength = message.content() != null ? message.content().length : 0; - validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent); + try { + MessageSender.validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent); + } catch (final MessageTooLargeException e) { + throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); + } totalContentLength += contentLength; } @@ -513,8 +508,13 @@ public class MessageController { } // Verify that the message isn't too large before performing more expensive validations - multiRecipientMessage.getRecipients().values().forEach(recipient -> - validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, false, isStory, userAgent)); + multiRecipientMessage.getRecipients().values().forEach(recipient -> { + try { + MessageSender.validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, false, isStory, userAgent); + } catch (final MessageTooLargeException e) { + throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); + } + }); // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same // recipient @@ -920,38 +920,4 @@ public class MessageController { throw e; } } - - private void validateContentLength(final int contentLength, - final boolean isMultiRecipientMessage, - final boolean isSyncMessage, - final boolean isStory, - final String userAgent) { - - final boolean oversize = contentLength > MAX_MESSAGE_SIZE; - - DistributionSummary.builder(CONTENT_SIZE_DISTRIBUTION_NAME) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - Tag.of("oversize", String.valueOf(oversize)), - Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), - Tag.of("syncMessage", String.valueOf(isSyncMessage)), - Tag.of("story", String.valueOf(isStory)))) - .publishPercentileHistogram(true) - .register(Metrics.globalRegistry) - .record(contentLength); - - if (oversize) { - Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), - Tag.of("syncMessage", String.valueOf(isSyncMessage)), - Tag.of("story", String.valueOf(isStory)))) - .increment(); - throw new WebApplicationException(Status.REQUEST_ENTITY_TOO_LARGE); - } - if (contentLength > LARGE_MESSAGE_SIZE) { - Metrics.counter( - LARGE_BUT_NOT_OVERSIZE_MESSAGE_COUNTER, - Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)))) - .increment(); - } - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index 5cd249a3a..f7e798e3a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -8,11 +8,18 @@ import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.util.DataSize; +import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import java.util.Map; import java.util.concurrent.CompletableFuture; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -34,6 +41,11 @@ public class MessageSender { private final MessagesManager messagesManager; private final PushNotificationManager pushNotificationManager; + // Note that these names deliberately reference `MessageController` for metric continuity + private static final String REJECT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "rejectOversizeMessage"); + private static final String LARGE_BUT_NOT_OVERSIZE_MESSAGE_COUNTER_NAME = name(MessageController.class, "largeMessage"); + private static final String CONTENT_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(MessageController.class, "messageContentSize"); + private static final String SEND_COUNTER_NAME = name(MessageSender.class, "sendMessage"); private static final String CHANNEL_TAG_NAME = "channel"; private static final String EPHEMERAL_TAG_NAME = "ephemeral"; @@ -42,6 +54,10 @@ public class MessageSender { private static final String STORY_TAG_NAME = "story"; private static final String SEALED_SENDER_TAG_NAME = "sealedSender"; + @VisibleForTesting + public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); + private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes(); + public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) { this.messagesManager = messagesManager; this.pushNotificationManager = pushNotificationManager; @@ -137,4 +153,40 @@ public class MessageSender { return "none"; } } + + public static void validateContentLength(final int contentLength, + final boolean isMultiRecipientMessage, + final boolean isSyncMessage, + final boolean isStory, + final String userAgent) throws MessageTooLargeException { + + final boolean oversize = contentLength > MAX_MESSAGE_SIZE; + + DistributionSummary.builder(CONTENT_SIZE_DISTRIBUTION_NAME) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of("oversize", String.valueOf(oversize)), + Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), + Tag.of("syncMessage", String.valueOf(isSyncMessage)), + Tag.of("story", String.valueOf(isStory)))) + .publishPercentileHistogram(true) + .register(Metrics.globalRegistry) + .record(contentLength); + + if (oversize) { + Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), + Tag.of("syncMessage", String.valueOf(isSyncMessage)), + Tag.of("story", String.valueOf(isStory)))) + .increment(); + + throw new MessageTooLargeException(); + } + + if (contentLength > LARGE_MESSAGE_SIZE) { + Metrics.counter( + LARGE_BUT_NOT_OVERSIZE_MESSAGE_COUNTER_NAME, + Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)))) + .increment(); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageTooLargeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageTooLargeException.java new file mode 100644 index 000000000..232958c64 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageTooLargeException.java @@ -0,0 +1,11 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +import org.whispersystems.textsecuregcm.util.NoStackTraceException; + +public class MessageTooLargeException extends NoStackTraceException { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceException.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceException.java new file mode 100644 index 000000000..4d49402fd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceException.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +/** + * An abstract base class for exceptions that do not include a stack trace. Stackless exceptions are generally intended + * for internal error-handling cases where the error will never be logged or otherwise reported. + */ +public abstract class NoStackTraceException extends Exception { + + public NoStackTraceException() { + super(null, null, true, false); + } + + public NoStackTraceException(final String message) { + super(message, null, true, false); + } + + public NoStackTraceException(final String message, final Throwable cause) { + super(message, cause, true, false); + } + + public NoStackTraceException(final Throwable cause) { + super(null, cause, true, false); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 2549b5521..e258f322b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -29,7 +29,6 @@ import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; -import com.fasterxml.jackson.core.JsonProcessingException; import com.google.protobuf.ByteString; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -490,7 +489,7 @@ class MessageControllerTest { // `long`) instead of the validation layer, we get a 400 instead of a 422 "99999999999999999999999999999999999, 400" }) - void testSingleDeviceExtremeTimestamp(final String timestamp, final int expectedStatus) throws JsonProcessingException { + void testSingleDeviceExtremeTimestamp(final String timestamp, final int expectedStatus) { final String jsonTemplate = """ { "timestamp" : %s, @@ -607,9 +606,7 @@ class MessageControllerTest { assertEquals(3, envelopeCaptor.getValue().size()); - envelopeCaptor.getValue().values().forEach(envelope -> { - assertTrue(envelope.getUrgent()); - }); + envelopeCaptor.getValue().values().forEach(envelope -> assertTrue(envelope.getUrgent())); } } @@ -633,9 +630,7 @@ class MessageControllerTest { assertEquals(3, envelopeCaptor.getValue().size()); - envelopeCaptor.getValue().values().forEach(envelope -> { - assertFalse(envelope.getUrgent()); - }); + envelopeCaptor.getValue().values().forEach(envelope -> assertFalse(envelope.getUrgent())); } } @@ -948,7 +943,7 @@ class MessageControllerTest { final UUID senderAci = UUID.randomUUID(); final UUID senderPni = UUID.randomUUID(); final String userAgent = "user-agent"; - UUID messageGuid = UUID.randomUUID(); + final UUID messageGuid = UUID.randomUUID(); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(senderAci); @@ -959,8 +954,6 @@ class MessageControllerTest { when(accountsManager.findRecentlyDeletedPhoneNumberIdentifier(senderAci)).thenReturn(Optional.of(senderPni)); when(phoneNumberIdentifiers.getPhoneNumber(senderPni)).thenReturn(CompletableFuture.completedFuture(List.of(senderNumber))); - messageGuid = UUID.randomUUID(); - try (final Response response = resources.getJerseyTest() .target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid)) @@ -1086,7 +1079,7 @@ class MessageControllerTest { @Test void testValidateContentLength() { - final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1); + final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1); final byte[] contentBytes = new byte[contentLength]; Arrays.fill(contentBytes, (byte) 1); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 1544e2635..5282e98eb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyByte; @@ -24,6 +25,7 @@ import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -172,4 +174,13 @@ class MessageSenderTest { return arguments; } + + @Test + void validateContentLength() { + assertThrows(MessageTooLargeException.class, () -> + MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE + 1, false, false, false, null)); + + assertDoesNotThrow(() -> + MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null)); + } }