Validate parsed message size, not base64-encoded message size
This commit is contained in:
parent
908a41814b
commit
6032764052
|
@ -167,7 +167,6 @@ public class MessageController {
|
|||
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
|
||||
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
|
||||
|
||||
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
|
||||
private static final String SEND_MESSAGE_LATENCY_TIMER_NAME = MetricsUtil.name(MessageController.class, "sendMessageLatency");
|
||||
|
||||
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
|
||||
|
@ -175,7 +174,6 @@ public class MessageController {
|
|||
private static final String AUTH_TYPE_TAG_NAME = "authType";
|
||||
private static final String SENDER_COUNTRY_TAG_NAME = "senderCountry";
|
||||
private static final String RATE_LIMIT_REASON_TAG_NAME = "rateLimitReason";
|
||||
private static final String ENVELOPE_TYPE_TAG_NAME = "envelopeType";
|
||||
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
|
||||
private static final String ENDPOINT_TYPE_TAG_NAME = "endpoint";
|
||||
|
||||
|
@ -192,7 +190,7 @@ public class MessageController {
|
|||
private static final String ENDPOINT_TYPE_MULTI = "multi";
|
||||
|
||||
@VisibleForTesting
|
||||
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
|
||||
static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
|
||||
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
|
||||
|
||||
// The Signal desktop client (really, JavaScript in general) can handle message timestamps at most 100,000,000 days
|
||||
|
@ -332,14 +330,9 @@ public class MessageController {
|
|||
int totalContentLength = 0;
|
||||
|
||||
for (final IncomingMessage message : messages.messages()) {
|
||||
int contentLength = 0;
|
||||
|
||||
if (StringUtils.isNotEmpty(message.content())) {
|
||||
contentLength += message.content().length();
|
||||
}
|
||||
final int contentLength = decodedSize(message.content());
|
||||
|
||||
validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent);
|
||||
validateEnvelopeType(message.type(), userAgent);
|
||||
|
||||
totalContentLength += contentLength;
|
||||
}
|
||||
|
@ -971,12 +964,18 @@ public class MessageController {
|
|||
}
|
||||
}
|
||||
|
||||
private void validateEnvelopeType(final int type, final String userAgent) {
|
||||
if (type == Type.SERVER_DELIVERY_RECEIPT_VALUE) {
|
||||
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE,
|
||||
Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENVELOPE_TYPE_TAG_NAME, String.valueOf(type))))
|
||||
.increment();
|
||||
throw new BadRequestException("reserved envelope type");
|
||||
@VisibleForTesting
|
||||
static int decodedSize(final String base64) {
|
||||
final int padding;
|
||||
|
||||
if (StringUtils.endsWith(base64, "==")) {
|
||||
padding = 2;
|
||||
} else if (StringUtils.endsWith(base64, "=")) {
|
||||
padding = 1;
|
||||
} else {
|
||||
padding = 0;
|
||||
}
|
||||
|
||||
return ((StringUtils.length(base64) - padding) * 3) / 4;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,15 +5,21 @@
|
|||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import java.util.Base64;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
||||
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
|
||||
|
||||
private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME =
|
||||
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
|
||||
|
||||
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
|
||||
@Nullable Account sourceAccount,
|
||||
@Nullable Byte sourceDeviceId,
|
||||
|
@ -23,15 +29,10 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
|
|||
final boolean urgent,
|
||||
@Nullable byte[] reportSpamToken) {
|
||||
|
||||
final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type());
|
||||
|
||||
if (envelopeType == null) {
|
||||
throw new IllegalArgumentException("Bad envelope type: " + type());
|
||||
}
|
||||
|
||||
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
|
||||
|
||||
envelopeBuilder.setType(envelopeType)
|
||||
envelopeBuilder
|
||||
.setType(MessageProtos.Envelope.Type.forNumber(type))
|
||||
.setClientTimestamp(timestamp)
|
||||
.setServerTimestamp(System.currentTimeMillis())
|
||||
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
|
||||
|
@ -55,4 +56,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
|
|||
|
||||
return envelopeBuilder.build();
|
||||
}
|
||||
|
||||
@AssertTrue
|
||||
public boolean isValidEnvelopeType() {
|
||||
if (type() == MessageProtos.Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE ||
|
||||
MessageProtos.Envelope.Type.forNumber(type()) == null) {
|
||||
|
||||
Metrics.counter(REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME).increment();
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,18 +8,16 @@ import static com.codahale.metrics.MetricRegistry.name;
|
|||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import jakarta.validation.constraints.Max;
|
||||
import jakarta.validation.constraints.PositiveOrZero;
|
||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
||||
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
|
||||
import java.util.List;
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import jakarta.validation.constraints.Max;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.PositiveOrZero;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
||||
|
||||
public record IncomingMessageList(@NotNull
|
||||
@Valid
|
||||
|
@ -49,10 +47,14 @@ public record IncomingMessageList(@NotNull
|
|||
|
||||
@AssertTrue
|
||||
public boolean hasNoDuplicateRecipients() {
|
||||
boolean valid = messages.stream().filter(m -> m != null).map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
|
||||
final boolean valid = messages.stream()
|
||||
.filter(Objects::nonNull)
|
||||
.map(IncomingMessage::destinationDeviceId).distinct().count() == messages.size();
|
||||
|
||||
if (!valid) {
|
||||
REJECT_DUPLICATE_RECIPIENT_COUNTER.increment();
|
||||
}
|
||||
|
||||
return valid;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import static org.hamcrest.CoreMatchers.equalTo;
|
|||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
@ -48,7 +47,6 @@ import java.time.ZoneOffset;
|
|||
import java.time.temporal.ChronoUnit;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
@ -1141,7 +1139,7 @@ class MessageControllerTest {
|
|||
assertEquals(200, response.getStatus());
|
||||
verify(messageSender).sendMessages(any(), any());
|
||||
} else {
|
||||
assertEquals(400, response.getStatus());
|
||||
assertEquals(422, response.getStatus());
|
||||
verify(messageSender, never()).sendMessages(any(), any());
|
||||
}
|
||||
}
|
||||
|
@ -1662,4 +1660,13 @@ class MessageControllerTest {
|
|||
return builder.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
void decodedSize() {
|
||||
for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) {
|
||||
final byte[] bytes = TestRandomUtil.nextBytes(size);
|
||||
final String base64Encoded = Base64.getEncoder().encodeToString(bytes);
|
||||
|
||||
assertEquals(bytes.length, MessageController.decodedSize(base64Encoded));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue