diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index f9b04ac70..a4457c92f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -138,7 +138,8 @@ public class MessageController { private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified"; private static final String SENDER_TYPE_SELF = "self"; - private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); + @VisibleForTesting + static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); public MessageController( RateLimiters rateLimiters, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index cb84c5be2..3ddcc34c1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -4,33 +4,63 @@ */ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; public class IncomingMessage { @JsonProperty - private int type; + private int type; @JsonProperty - private String destination; + private final String destination; @JsonProperty - private long destinationDeviceId = 1; + private long destinationDeviceId; @JsonProperty private int destinationRegistrationId; @JsonProperty - private String body; + private final String body; @JsonProperty - private String content; + private final String content; @JsonProperty - private String relay; + private final String relay; @JsonProperty - private long timestamp; // deprecated + private long timestamp; // deprecated + + @JsonCreator + public IncomingMessage( + @JsonProperty("id") final Integer type, + @JsonProperty("destination") final String destination, + @JsonProperty("destinationDeviceId") final Long destinationDeviceId, + @JsonProperty("destinationRegistrationId") final Integer destinationRegistrationId, + @JsonProperty("body") final String body, + @JsonProperty("content") final String content, + @JsonProperty("relay") final String relay, + @JsonProperty("timestamp") final Long timestamp) { + if (type != null) { + this.type = type; + } + this.destination = destination; + + if (destinationDeviceId != null) { + this.destinationDeviceId = destinationDeviceId; + } + if (destinationRegistrationId != null) { + this.destinationRegistrationId = destinationRegistrationId; + } + this.body = body; + this.content = content; + this.relay = relay; + if (timestamp != null) { + this.timestamp = timestamp; + } + } public String getDestination() { return destination; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java index 3572ff299..555b59e92 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java @@ -4,6 +4,7 @@ */ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import javax.validation.Valid; @@ -14,7 +15,7 @@ public class IncomingMessageList { @JsonProperty @NotNull @Valid - private List<@NotNull IncomingMessage> messages; + private final List<@NotNull IncomingMessage> messages; @JsonProperty private long timestamp; @@ -22,7 +23,19 @@ public class IncomingMessageList { @JsonProperty private boolean online; - public IncomingMessageList() {} + @JsonCreator + public IncomingMessageList( + @JsonProperty("messages") final List<@NotNull IncomingMessage> messages, + @JsonProperty("online") final Boolean online, + @JsonProperty("timestamp") final Long timestamp) { + this.messages = messages; + if (timestamp != null) { + this.timestamp = timestamp; + } + if (online != null) { + this.online = online; + } + } public List getMessages() { return messages; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java index e2ef4dce4..8649254bf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java @@ -16,6 +16,7 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.NoContentException; +import javax.ws.rs.core.Response.Status; import javax.ws.rs.ext.MessageBodyReader; import javax.ws.rs.ext.Provider; import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage; @@ -66,7 +67,7 @@ public class MultiDeviceMessageListProvider extends BinaryProviderBase implement long messageLength = readVarint(entityStream); if (messageLength > MAX_MESSAGE_SIZE) { - throw new BadRequestException("Message body too large"); + throw new WebApplicationException("Message body too large", Status.REQUEST_ENTITY_TOO_LARGE); } byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength)); if (contents.length != messageLength) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 3d78087df..2b9779442 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -36,6 +36,7 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.io.ByteArrayOutputStream; +import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.HashSet; @@ -63,6 +64,7 @@ import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; +import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; @@ -612,6 +614,59 @@ class MessageControllerTest { verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID); } + @ParameterizedTest + @MethodSource + void testValidateContentLength(Entity payload) throws Exception { + Response response = + resources.getJerseyTest() + .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) + .request() + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) + .put(payload); + + assertThat("Bad response", response.getStatus(), is(equalTo(413))); + + verify(messageSender, never()).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), + anyBoolean()); + } + + private static Stream> testValidateContentLength() { + final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1); + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(1); // count + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[]{(byte) 0, (byte) 111}); // registration ID + messageStream.write(1); // message type + writeVarint(contentLength, messageStream); // message length + final byte[] contentBytes = new byte[contentLength]; + Arrays.fill(contentBytes, (byte) 1); + messageStream.writeBytes(contentBytes); // message contents + + try { + return Stream.of( + Entity.entity(new IncomingMessageList( + List.of(new IncomingMessage(1, null, 1L, null, new String(contentBytes), null, null, null)), null, null), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + private static void writeVarint(int value, ByteArrayOutputStream outputStream) { + while (true) { + int bits = value & 0x7f; + value >>>= 7; + if (value == 0) { + outputStream.write((byte) bits); + return; + } + outputStream.write((byte) (bits | 0x80)); + } + } + @ParameterizedTest @MethodSource void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {