Add test for content length validation
This commit is contained in:
		
							parent
							
								
									579eb85175
								
							
						
					
					
						commit
						73fa3c3fe4
					
				| 
						 | 
				
			
			@ -138,7 +138,8 @@ public class MessageController {
 | 
			
		|||
  private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified";
 | 
			
		||||
  private static final String SENDER_TYPE_SELF = "self";
 | 
			
		||||
 | 
			
		||||
  private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
 | 
			
		||||
  @VisibleForTesting
 | 
			
		||||
  static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
 | 
			
		||||
 | 
			
		||||
  public MessageController(
 | 
			
		||||
      RateLimiters rateLimiters,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,33 +4,63 @@
 | 
			
		|||
 */
 | 
			
		||||
package org.whispersystems.textsecuregcm.entities;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonCreator;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
 | 
			
		||||
public class IncomingMessage {
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private int    type;
 | 
			
		||||
  private int type;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private String destination;
 | 
			
		||||
  private final String destination;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private long   destinationDeviceId = 1;
 | 
			
		||||
  private long destinationDeviceId;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private int destinationRegistrationId;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private String body;
 | 
			
		||||
  private final String body;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private String content;
 | 
			
		||||
  private final String content;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private String relay;
 | 
			
		||||
  private final String relay;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private long   timestamp; // deprecated
 | 
			
		||||
  private long timestamp; // deprecated
 | 
			
		||||
 | 
			
		||||
  @JsonCreator
 | 
			
		||||
  public IncomingMessage(
 | 
			
		||||
      @JsonProperty("id") final Integer type,
 | 
			
		||||
      @JsonProperty("destination") final String destination,
 | 
			
		||||
      @JsonProperty("destinationDeviceId") final Long destinationDeviceId,
 | 
			
		||||
      @JsonProperty("destinationRegistrationId") final Integer destinationRegistrationId,
 | 
			
		||||
      @JsonProperty("body") final String body,
 | 
			
		||||
      @JsonProperty("content") final String content,
 | 
			
		||||
      @JsonProperty("relay") final String relay,
 | 
			
		||||
      @JsonProperty("timestamp") final Long timestamp) {
 | 
			
		||||
    if (type != null) {
 | 
			
		||||
      this.type = type;
 | 
			
		||||
    }
 | 
			
		||||
    this.destination = destination;
 | 
			
		||||
 | 
			
		||||
    if (destinationDeviceId != null) {
 | 
			
		||||
      this.destinationDeviceId = destinationDeviceId;
 | 
			
		||||
    }
 | 
			
		||||
    if (destinationRegistrationId != null) {
 | 
			
		||||
      this.destinationRegistrationId = destinationRegistrationId;
 | 
			
		||||
    }
 | 
			
		||||
    this.body = body;
 | 
			
		||||
    this.content = content;
 | 
			
		||||
    this.relay = relay;
 | 
			
		||||
    if (timestamp != null) {
 | 
			
		||||
      this.timestamp = timestamp;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public String getDestination() {
 | 
			
		||||
    return destination;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,6 +4,7 @@
 | 
			
		|||
 */
 | 
			
		||||
package org.whispersystems.textsecuregcm.entities;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonCreator;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import javax.validation.Valid;
 | 
			
		||||
| 
						 | 
				
			
			@ -14,7 +15,7 @@ public class IncomingMessageList {
 | 
			
		|||
  @JsonProperty
 | 
			
		||||
  @NotNull
 | 
			
		||||
  @Valid
 | 
			
		||||
  private List<@NotNull IncomingMessage> messages;
 | 
			
		||||
  private final List<@NotNull IncomingMessage> messages;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private long timestamp;
 | 
			
		||||
| 
						 | 
				
			
			@ -22,7 +23,19 @@ public class IncomingMessageList {
 | 
			
		|||
  @JsonProperty
 | 
			
		||||
  private boolean online;
 | 
			
		||||
 | 
			
		||||
  public IncomingMessageList() {}
 | 
			
		||||
  @JsonCreator
 | 
			
		||||
  public IncomingMessageList(
 | 
			
		||||
      @JsonProperty("messages") final List<@NotNull IncomingMessage> messages,
 | 
			
		||||
      @JsonProperty("online") final Boolean online,
 | 
			
		||||
      @JsonProperty("timestamp") final Long timestamp) {
 | 
			
		||||
    this.messages = messages;
 | 
			
		||||
    if (timestamp != null) {
 | 
			
		||||
      this.timestamp = timestamp;
 | 
			
		||||
    }
 | 
			
		||||
    if (online != null) {
 | 
			
		||||
      this.online = online;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public List<IncomingMessage> getMessages() {
 | 
			
		||||
    return messages;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@ import javax.ws.rs.WebApplicationException;
 | 
			
		|||
import javax.ws.rs.core.MediaType;
 | 
			
		||||
import javax.ws.rs.core.MultivaluedMap;
 | 
			
		||||
import javax.ws.rs.core.NoContentException;
 | 
			
		||||
import javax.ws.rs.core.Response.Status;
 | 
			
		||||
import javax.ws.rs.ext.MessageBodyReader;
 | 
			
		||||
import javax.ws.rs.ext.Provider;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
 | 
			
		||||
| 
						 | 
				
			
			@ -66,7 +67,7 @@ public class MultiDeviceMessageListProvider extends BinaryProviderBase implement
 | 
			
		|||
 | 
			
		||||
      long messageLength = readVarint(entityStream);
 | 
			
		||||
      if (messageLength > MAX_MESSAGE_SIZE) {
 | 
			
		||||
        throw new BadRequestException("Message body too large");
 | 
			
		||||
        throw new WebApplicationException("Message body too large", Status.REQUEST_ENTITY_TOO_LARGE);
 | 
			
		||||
      }
 | 
			
		||||
      byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
 | 
			
		||||
      if (contents.length != messageLength) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -36,6 +36,7 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
 | 
			
		|||
import io.dropwizard.testing.junit5.ResourceExtension;
 | 
			
		||||
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
 | 
			
		||||
import java.io.ByteArrayOutputStream;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.Base64;
 | 
			
		||||
import java.util.Collection;
 | 
			
		||||
import java.util.HashSet;
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +64,7 @@ import org.mockito.ArgumentCaptor;
 | 
			
		|||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
 | 
			
		||||
| 
						 | 
				
			
			@ -612,6 +614,59 @@ class MessageControllerTest {
 | 
			
		|||
    verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @ParameterizedTest
 | 
			
		||||
  @MethodSource
 | 
			
		||||
  void testValidateContentLength(Entity<?> payload) throws Exception {
 | 
			
		||||
    Response response =
 | 
			
		||||
        resources.getJerseyTest()
 | 
			
		||||
            .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
 | 
			
		||||
            .request()
 | 
			
		||||
            .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
 | 
			
		||||
            .put(payload);
 | 
			
		||||
 | 
			
		||||
    assertThat("Bad response", response.getStatus(), is(equalTo(413)));
 | 
			
		||||
 | 
			
		||||
    verify(messageSender, never()).sendMessage(any(Account.class), any(Device.class), any(Envelope.class),
 | 
			
		||||
        anyBoolean());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private static Stream<Entity<?>> testValidateContentLength() {
 | 
			
		||||
    final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);
 | 
			
		||||
    ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
 | 
			
		||||
    messageStream.write(1); // version
 | 
			
		||||
    messageStream.write(1); // count
 | 
			
		||||
    messageStream.write(1); // device ID
 | 
			
		||||
    messageStream.writeBytes(new byte[]{(byte) 0, (byte) 111}); // registration ID
 | 
			
		||||
    messageStream.write(1); // message type
 | 
			
		||||
    writeVarint(contentLength, messageStream); // message length
 | 
			
		||||
    final byte[] contentBytes = new byte[contentLength];
 | 
			
		||||
    Arrays.fill(contentBytes, (byte) 1);
 | 
			
		||||
    messageStream.writeBytes(contentBytes); // message contents
 | 
			
		||||
 | 
			
		||||
    try {
 | 
			
		||||
      return Stream.of(
 | 
			
		||||
          Entity.entity(new IncomingMessageList(
 | 
			
		||||
                  List.of(new IncomingMessage(1, null, 1L, null, new String(contentBytes), null, null, null)), null, null),
 | 
			
		||||
              MediaType.APPLICATION_JSON_TYPE),
 | 
			
		||||
          Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
 | 
			
		||||
      );
 | 
			
		||||
    } catch (Exception e) {
 | 
			
		||||
      throw new AssertionError(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private static void writeVarint(int value, ByteArrayOutputStream outputStream) {
 | 
			
		||||
    while (true) {
 | 
			
		||||
      int bits = value & 0x7f;
 | 
			
		||||
      value >>>= 7;
 | 
			
		||||
      if (value == 0) {
 | 
			
		||||
        outputStream.write((byte) bits);
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      outputStream.write((byte) (bits | 0x80));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @ParameterizedTest
 | 
			
		||||
  @MethodSource
 | 
			
		||||
  void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue