Validate parsed message size, not base64-encoded message size

This commit is contained in:
Jon Chambers 2025-02-10 17:13:24 -05:00 committed by GitHub
parent 908a41814b
commit 6032764052
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 55 additions and 33 deletions

View File

@ -167,7 +167,6 @@ public class MessageController {
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
private static final String SEND_MESSAGE_LATENCY_TIMER_NAME = MetricsUtil.name(MessageController.class, "sendMessageLatency");
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
@ -175,7 +174,6 @@ public class MessageController {
private static final String AUTH_TYPE_TAG_NAME = "authType";
private static final String SENDER_COUNTRY_TAG_NAME = "senderCountry";
private static final String RATE_LIMIT_REASON_TAG_NAME = "rateLimitReason";
private static final String ENVELOPE_TYPE_TAG_NAME = "envelopeType";
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
private static final String ENDPOINT_TYPE_TAG_NAME = "endpoint";
@ -192,7 +190,7 @@ public class MessageController {
private static final String ENDPOINT_TYPE_MULTI = "multi";
@VisibleForTesting
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
// The Signal desktop client (really, JavaScript in general) can handle message timestamps at most 100,000,000 days
@ -332,14 +330,9 @@ public class MessageController {
int totalContentLength = 0;
for (final IncomingMessage message : messages.messages()) {
int contentLength = 0;
if (StringUtils.isNotEmpty(message.content())) {
contentLength += message.content().length();
}
final int contentLength = decodedSize(message.content());
validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent);
validateEnvelopeType(message.type(), userAgent);
totalContentLength += contentLength;
}
@ -971,12 +964,18 @@ public class MessageController {
}
}
private void validateEnvelopeType(final int type, final String userAgent) {
if (type == Type.SERVER_DELIVERY_RECEIPT_VALUE) {
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE,
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENVELOPE_TYPE_TAG_NAME, String.valueOf(type))))
.increment();
throw new BadRequestException("reserved envelope type");
@VisibleForTesting
static int decodedSize(final String base64) {
final int padding;
if (StringUtils.endsWith(base64, "==")) {
padding = 2;
} else if (StringUtils.endsWith(base64, "=")) {
padding = 1;
} else {
padding = 0;
}
return ((StringUtils.length(base64) - padding) * 3) / 4;
}
}

View File

@ -5,15 +5,21 @@
package org.whispersystems.textsecuregcm.entities;
import com.google.protobuf.ByteString;
import io.micrometer.core.instrument.Metrics;
import jakarta.validation.constraints.AssertTrue;
import java.util.Base64;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME =
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
@Nullable Account sourceAccount,
@Nullable Byte sourceDeviceId,
@ -23,15 +29,10 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
final boolean urgent,
@Nullable byte[] reportSpamToken) {
final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type());
if (envelopeType == null) {
throw new IllegalArgumentException("Bad envelope type: " + type());
}
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
envelopeBuilder.setType(envelopeType)
envelopeBuilder
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setClientTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis())
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
@ -55,4 +56,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
return envelopeBuilder.build();
}
@AssertTrue
public boolean isValidEnvelopeType() {
if (type() == MessageProtos.Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE ||
MessageProtos.Envelope.Type.forNumber(type()) == null) {
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME).increment();
return false;
}
return true;
}
}

View File

@ -8,18 +8,16 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.PositiveOrZero;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.util.List;
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.PositiveOrZero;
import java.util.List;
import java.util.Objects;
import org.whispersystems.textsecuregcm.controllers.MessageController;
public record IncomingMessageList(@NotNull
@Valid
@ -49,10 +47,14 @@ public record IncomingMessageList(@NotNull
@AssertTrue
public boolean hasNoDuplicateRecipients() {
boolean valid = messages.stream().filter(m -> m != null).map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
final boolean valid = messages.stream()
.filter(Objects::nonNull)
.map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
if (!valid) {
REJECT_DUPLICATE_RECIPIENT_COUNTER.increment();
}
return valid;
}
}

View File

@ -9,7 +9,6 @@ import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -48,7 +47,6 @@ import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@ -1141,7 +1139,7 @@ class MessageControllerTest {
assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any());
} else {
assertEquals(400, response.getStatus());
assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any());
}
}
@ -1662,4 +1660,13 @@ class MessageControllerTest {
return builder.build();
}
@Test
void decodedSize() {
for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) {
final byte[] bytes = TestRandomUtil.nextBytes(size);
final String base64Encoded = Base64.getEncoder().encodeToString(bytes);
assertEquals(bytes.length, MessageController.decodedSize(base64Encoded));
}
}
}