Add test for content length validation

This commit is contained in:
Chris Eager 2022-02-15 17:40:27 -08:00 committed by Chris Eager
parent 579eb85175
commit 73fa3c3fe4
5 changed files with 111 additions and 11 deletions

View File

@ -138,7 +138,8 @@ public class MessageController {
private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified"; private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified";
private static final String SENDER_TYPE_SELF = "self"; private static final String SENDER_TYPE_SELF = "self";
private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); @VisibleForTesting
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
public MessageController( public MessageController(
RateLimiters rateLimiters, RateLimiters rateLimiters,

View File

@ -4,33 +4,63 @@
*/ */
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
public class IncomingMessage { public class IncomingMessage {
@JsonProperty @JsonProperty
private int type; private int type;
@JsonProperty @JsonProperty
private String destination; private final String destination;
@JsonProperty @JsonProperty
private long destinationDeviceId = 1; private long destinationDeviceId;
@JsonProperty @JsonProperty
private int destinationRegistrationId; private int destinationRegistrationId;
@JsonProperty @JsonProperty
private String body; private final String body;
@JsonProperty @JsonProperty
private String content; private final String content;
@JsonProperty @JsonProperty
private String relay; private final String relay;
@JsonProperty @JsonProperty
private long timestamp; // deprecated private long timestamp; // deprecated
@JsonCreator
public IncomingMessage(
@JsonProperty("id") final Integer type,
@JsonProperty("destination") final String destination,
@JsonProperty("destinationDeviceId") final Long destinationDeviceId,
@JsonProperty("destinationRegistrationId") final Integer destinationRegistrationId,
@JsonProperty("body") final String body,
@JsonProperty("content") final String content,
@JsonProperty("relay") final String relay,
@JsonProperty("timestamp") final Long timestamp) {
if (type != null) {
this.type = type;
}
this.destination = destination;
if (destinationDeviceId != null) {
this.destinationDeviceId = destinationDeviceId;
}
if (destinationRegistrationId != null) {
this.destinationRegistrationId = destinationRegistrationId;
}
this.body = body;
this.content = content;
this.relay = relay;
if (timestamp != null) {
this.timestamp = timestamp;
}
}
public String getDestination() { public String getDestination() {
return destination; return destination;

View File

@ -4,6 +4,7 @@
*/ */
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List; import java.util.List;
import javax.validation.Valid; import javax.validation.Valid;
@ -14,7 +15,7 @@ public class IncomingMessageList {
@JsonProperty @JsonProperty
@NotNull @NotNull
@Valid @Valid
private List<@NotNull IncomingMessage> messages; private final List<@NotNull IncomingMessage> messages;
@JsonProperty @JsonProperty
private long timestamp; private long timestamp;
@ -22,7 +23,19 @@ public class IncomingMessageList {
@JsonProperty @JsonProperty
private boolean online; private boolean online;
public IncomingMessageList() {} @JsonCreator
public IncomingMessageList(
@JsonProperty("messages") final List<@NotNull IncomingMessage> messages,
@JsonProperty("online") final Boolean online,
@JsonProperty("timestamp") final Long timestamp) {
this.messages = messages;
if (timestamp != null) {
this.timestamp = timestamp;
}
if (online != null) {
this.online = online;
}
}
public List<IncomingMessage> getMessages() { public List<IncomingMessage> getMessages() {
return messages; return messages;

View File

@ -16,6 +16,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.NoContentException; import javax.ws.rs.core.NoContentException;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.ext.MessageBodyReader; import javax.ws.rs.ext.MessageBodyReader;
import javax.ws.rs.ext.Provider; import javax.ws.rs.ext.Provider;
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage; import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
@ -66,7 +67,7 @@ public class MultiDeviceMessageListProvider extends BinaryProviderBase implement
long messageLength = readVarint(entityStream); long messageLength = readVarint(entityStream);
if (messageLength > MAX_MESSAGE_SIZE) { if (messageLength > MAX_MESSAGE_SIZE) {
throw new BadRequestException("Message body too large"); throw new WebApplicationException("Message body too large", Status.REQUEST_ENTITY_TOO_LARGE);
} }
byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength)); byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
if (contents.length != messageLength) { if (contents.length != messageLength) {

View File

@ -36,6 +36,7 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
@ -63,6 +64,7 @@ import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
@ -612,6 +614,59 @@ class MessageControllerTest {
verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID); verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID);
} }
@ParameterizedTest
@MethodSource
void testValidateContentLength(Entity<?> payload) throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
.put(payload);
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessage(any(Account.class), any(Device.class), any(Envelope.class),
anyBoolean());
}
private static Stream<Entity<?>> testValidateContentLength() {
final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(1); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[]{(byte) 0, (byte) 111}); // registration ID
messageStream.write(1); // message type
writeVarint(contentLength, messageStream); // message length
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
messageStream.writeBytes(contentBytes); // message contents
try {
return Stream.of(
Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, null, 1L, null, new String(contentBytes), null, null, null)), null, null),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
private static void writeVarint(int value, ByteArrayOutputStream outputStream) {
while (true) {
int bits = value & 0x7f;
value >>>= 7;
if (value == 0) {
outputStream.write((byte) bits);
return;
}
outputStream.write((byte) (bits | 0x80));
}
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception { void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {