Add test for content length validation
This commit is contained in:
parent
579eb85175
commit
73fa3c3fe4
|
@ -138,7 +138,8 @@ public class MessageController {
|
||||||
private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified";
|
private static final String SENDER_TYPE_UNIDENTIFIED = "unidentified";
|
||||||
private static final String SENDER_TYPE_SELF = "self";
|
private static final String SENDER_TYPE_SELF = "self";
|
||||||
|
|
||||||
private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
|
@VisibleForTesting
|
||||||
|
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
|
||||||
|
|
||||||
public MessageController(
|
public MessageController(
|
||||||
RateLimiters rateLimiters,
|
RateLimiters rateLimiters,
|
||||||
|
|
|
@ -4,33 +4,63 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
public class IncomingMessage {
|
public class IncomingMessage {
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private int type;
|
private int type;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private String destination;
|
private final String destination;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private long destinationDeviceId = 1;
|
private long destinationDeviceId;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private int destinationRegistrationId;
|
private int destinationRegistrationId;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private String body;
|
private final String body;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private String content;
|
private final String content;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private String relay;
|
private final String relay;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private long timestamp; // deprecated
|
private long timestamp; // deprecated
|
||||||
|
|
||||||
|
@JsonCreator
|
||||||
|
public IncomingMessage(
|
||||||
|
@JsonProperty("id") final Integer type,
|
||||||
|
@JsonProperty("destination") final String destination,
|
||||||
|
@JsonProperty("destinationDeviceId") final Long destinationDeviceId,
|
||||||
|
@JsonProperty("destinationRegistrationId") final Integer destinationRegistrationId,
|
||||||
|
@JsonProperty("body") final String body,
|
||||||
|
@JsonProperty("content") final String content,
|
||||||
|
@JsonProperty("relay") final String relay,
|
||||||
|
@JsonProperty("timestamp") final Long timestamp) {
|
||||||
|
if (type != null) {
|
||||||
|
this.type = type;
|
||||||
|
}
|
||||||
|
this.destination = destination;
|
||||||
|
|
||||||
|
if (destinationDeviceId != null) {
|
||||||
|
this.destinationDeviceId = destinationDeviceId;
|
||||||
|
}
|
||||||
|
if (destinationRegistrationId != null) {
|
||||||
|
this.destinationRegistrationId = destinationRegistrationId;
|
||||||
|
}
|
||||||
|
this.body = body;
|
||||||
|
this.content = content;
|
||||||
|
this.relay = relay;
|
||||||
|
if (timestamp != null) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public String getDestination() {
|
public String getDestination() {
|
||||||
return destination;
|
return destination;
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
|
@ -14,7 +15,7 @@ public class IncomingMessageList {
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
@NotNull
|
@NotNull
|
||||||
@Valid
|
@Valid
|
||||||
private List<@NotNull IncomingMessage> messages;
|
private final List<@NotNull IncomingMessage> messages;
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private long timestamp;
|
private long timestamp;
|
||||||
|
@ -22,7 +23,19 @@ public class IncomingMessageList {
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private boolean online;
|
private boolean online;
|
||||||
|
|
||||||
public IncomingMessageList() {}
|
@JsonCreator
|
||||||
|
public IncomingMessageList(
|
||||||
|
@JsonProperty("messages") final List<@NotNull IncomingMessage> messages,
|
||||||
|
@JsonProperty("online") final Boolean online,
|
||||||
|
@JsonProperty("timestamp") final Long timestamp) {
|
||||||
|
this.messages = messages;
|
||||||
|
if (timestamp != null) {
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
}
|
||||||
|
if (online != null) {
|
||||||
|
this.online = online;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public List<IncomingMessage> getMessages() {
|
public List<IncomingMessage> getMessages() {
|
||||||
return messages;
|
return messages;
|
||||||
|
|
|
@ -16,6 +16,7 @@ import javax.ws.rs.WebApplicationException;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
import javax.ws.rs.core.MultivaluedMap;
|
import javax.ws.rs.core.MultivaluedMap;
|
||||||
import javax.ws.rs.core.NoContentException;
|
import javax.ws.rs.core.NoContentException;
|
||||||
|
import javax.ws.rs.core.Response.Status;
|
||||||
import javax.ws.rs.ext.MessageBodyReader;
|
import javax.ws.rs.ext.MessageBodyReader;
|
||||||
import javax.ws.rs.ext.Provider;
|
import javax.ws.rs.ext.Provider;
|
||||||
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
|
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
|
||||||
|
@ -66,7 +67,7 @@ public class MultiDeviceMessageListProvider extends BinaryProviderBase implement
|
||||||
|
|
||||||
long messageLength = readVarint(entityStream);
|
long messageLength = readVarint(entityStream);
|
||||||
if (messageLength > MAX_MESSAGE_SIZE) {
|
if (messageLength > MAX_MESSAGE_SIZE) {
|
||||||
throw new BadRequestException("Message body too large");
|
throw new WebApplicationException("Message body too large", Status.REQUEST_ENTITY_TOO_LARGE);
|
||||||
}
|
}
|
||||||
byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
|
byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
|
||||||
if (contents.length != messageLength) {
|
if (contents.length != messageLength) {
|
||||||
|
|
|
@ -36,6 +36,7 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||||
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
|
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
@ -63,6 +64,7 @@ import org.mockito.ArgumentCaptor;
|
||||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
||||||
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
|
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||||
|
@ -612,6 +614,59 @@ class MessageControllerTest {
|
||||||
verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID);
|
verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource
|
||||||
|
void testValidateContentLength(Entity<?> payload) throws Exception {
|
||||||
|
Response response =
|
||||||
|
resources.getJerseyTest()
|
||||||
|
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||||
|
.request()
|
||||||
|
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
|
||||||
|
.put(payload);
|
||||||
|
|
||||||
|
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
|
||||||
|
|
||||||
|
verify(messageSender, never()).sendMessage(any(Account.class), any(Device.class), any(Envelope.class),
|
||||||
|
anyBoolean());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Stream<Entity<?>> testValidateContentLength() {
|
||||||
|
final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);
|
||||||
|
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
|
||||||
|
messageStream.write(1); // version
|
||||||
|
messageStream.write(1); // count
|
||||||
|
messageStream.write(1); // device ID
|
||||||
|
messageStream.writeBytes(new byte[]{(byte) 0, (byte) 111}); // registration ID
|
||||||
|
messageStream.write(1); // message type
|
||||||
|
writeVarint(contentLength, messageStream); // message length
|
||||||
|
final byte[] contentBytes = new byte[contentLength];
|
||||||
|
Arrays.fill(contentBytes, (byte) 1);
|
||||||
|
messageStream.writeBytes(contentBytes); // message contents
|
||||||
|
|
||||||
|
try {
|
||||||
|
return Stream.of(
|
||||||
|
Entity.entity(new IncomingMessageList(
|
||||||
|
List.of(new IncomingMessage(1, null, 1L, null, new String(contentBytes), null, null, null)), null, null),
|
||||||
|
MediaType.APPLICATION_JSON_TYPE),
|
||||||
|
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||||
|
);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new AssertionError(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void writeVarint(int value, ByteArrayOutputStream outputStream) {
|
||||||
|
while (true) {
|
||||||
|
int bits = value & 0x7f;
|
||||||
|
value >>>= 7;
|
||||||
|
if (value == 0) {
|
||||||
|
outputStream.write((byte) bits);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
outputStream.write((byte) (bits | 0x80));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource
|
@MethodSource
|
||||||
void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {
|
void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {
|
||||||
|
|
Loading…
Reference in New Issue