diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 3169c2919..ad17607fb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -167,7 +167,6 @@ public class MessageController { private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage"); - private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType"); private static final String SEND_MESSAGE_LATENCY_TIMER_NAME = MetricsUtil.name(MessageController.class, "sendMessageLatency"); private static final String EPHEMERAL_TAG_NAME = "ephemeral"; @@ -175,7 +174,6 @@ public class MessageController { private static final String AUTH_TYPE_TAG_NAME = "authType"; private static final String SENDER_COUNTRY_TAG_NAME = "senderCountry"; private static final String RATE_LIMIT_REASON_TAG_NAME = "rateLimitReason"; - private static final String ENVELOPE_TYPE_TAG_NAME = "envelopeType"; private static final String IDENTITY_TYPE_TAG_NAME = "identityType"; private static final String ENDPOINT_TYPE_TAG_NAME = "endpoint"; @@ -192,7 +190,7 @@ public class MessageController { private static final String ENDPOINT_TYPE_MULTI = "multi"; @VisibleForTesting - static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); + static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes(); // The Signal desktop client (really, JavaScript in general) can handle message timestamps at most 100,000,000 days @@ -332,14 +330,9 @@ public class MessageController { int totalContentLength = 0; for (final IncomingMessage message : messages.messages()) { - int contentLength = 0; - - if (StringUtils.isNotEmpty(message.content())) { - contentLength += message.content().length(); - } + final int contentLength = decodedSize(message.content()); validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent); - validateEnvelopeType(message.type(), userAgent); totalContentLength += contentLength; } @@ -971,12 +964,18 @@ public class MessageController { } } - private void validateEnvelopeType(final int type, final String userAgent) { - if (type == Type.SERVER_DELIVERY_RECEIPT_VALUE) { - Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE, - Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENVELOPE_TYPE_TAG_NAME, String.valueOf(type)))) - .increment(); - throw new BadRequestException("reserved envelope type"); + @VisibleForTesting + static int decodedSize(final String base64) { + final int padding; + + if (StringUtils.endsWith(base64, "==")) { + padding = 2; + } else if (StringUtils.endsWith(base64, "=")) { + padding = 1; + } else { + padding = 0; } + + return ((StringUtils.length(base64) - padding) * 3) / 4; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 3422d44a5..5ade8597d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -5,15 +5,21 @@ package org.whispersystems.textsecuregcm.entities; import com.google.protobuf.ByteString; +import io.micrometer.core.instrument.Metrics; +import jakarta.validation.constraints.AssertTrue; import java.util.Base64; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.Account; public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) { + private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME = + MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType"); + public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier, @Nullable Account sourceAccount, @Nullable Byte sourceDeviceId, @@ -23,15 +29,10 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio final boolean urgent, @Nullable byte[] reportSpamToken) { - final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); - - if (envelopeType == null) { - throw new IllegalArgumentException("Bad envelope type: " + type()); - } - final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder(); - envelopeBuilder.setType(envelopeType) + envelopeBuilder + .setType(MessageProtos.Envelope.Type.forNumber(type)) .setClientTimestamp(timestamp) .setServerTimestamp(System.currentTimeMillis()) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) @@ -55,4 +56,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio return envelopeBuilder.build(); } + + @AssertTrue + public boolean isValidEnvelopeType() { + if (type() == MessageProtos.Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE || + MessageProtos.Envelope.Type.forNumber(type()) == null) { + + Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME).increment(); + + return false; + } + + return true; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java index 440b50fcc..7af872970 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java @@ -8,18 +8,16 @@ import static com.codahale.metrics.MetricRegistry.name; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import jakarta.validation.constraints.Max; -import jakarta.validation.constraints.PositiveOrZero; -import org.whispersystems.textsecuregcm.controllers.MessageController; - import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; - -import java.util.List; import jakarta.validation.Valid; import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.Max; import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.PositiveOrZero; +import java.util.List; +import java.util.Objects; +import org.whispersystems.textsecuregcm.controllers.MessageController; public record IncomingMessageList(@NotNull @Valid @@ -49,10 +47,14 @@ public record IncomingMessageList(@NotNull @AssertTrue public boolean hasNoDuplicateRecipients() { - boolean valid = messages.stream().filter(m -> m != null).map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size(); + final boolean valid = messages.stream() + .filter(Objects::nonNull) + .map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size(); + if (!valid) { REJECT_DUPLICATE_RECIPIENT_COUNTER.increment(); } + return valid; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 80577aa55..ec5dc198a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -9,7 +9,6 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -48,7 +47,6 @@ import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Base64; -import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -1141,7 +1139,7 @@ class MessageControllerTest { assertEquals(200, response.getStatus()); verify(messageSender).sendMessages(any(), any()); } else { - assertEquals(400, response.getStatus()); + assertEquals(422, response.getStatus()); verify(messageSender, never()).sendMessages(any(), any()); } } @@ -1662,4 +1660,13 @@ class MessageControllerTest { return builder.build(); } + @Test + void decodedSize() { + for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) { + final byte[] bytes = TestRandomUtil.nextBytes(size); + final String base64Encoded = Base64.getEncoder().encodeToString(bytes); + + assertEquals(bytes.length, MessageController.decodedSize(base64Encoded)); + } + } }