Retire the (unused!) binary message format

This commit is contained in:
Jon Chambers 2022-07-27 11:22:20 -04:00 committed by Jon Chambers
parent aa36dc95ef
commit e9119da040
9 changed files with 144 additions and 829 deletions

View File

@ -138,7 +138,6 @@ import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
import org.whispersystems.textsecuregcm.metrics.TrafficSource; import org.whispersystems.textsecuregcm.metrics.TrafficSource;
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
@ -591,7 +590,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP)); environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP));
environment.jersey().register(MultiDeviceMessageListProvider.class);
environment.jersey().register(MultiRecipientMessageProvider.class); environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP)); environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP));
environment.jersey() environment.jersey()
@ -613,7 +611,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
clientPresenceManager, websocketScheduledExecutor)); clientPresenceManager, websocketScheduledExecutor));
webSocketEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); webSocketEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new ContentLengthFilter(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new ContentLengthFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiDeviceMessageListProvider.class);
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class); webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));

View File

@ -61,7 +61,6 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@ -77,7 +76,6 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
@ -264,113 +262,6 @@ public class MessageController {
} }
} }
@Timed
@Path("/{destination}")
@PUT
@Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE)
@Produces(MediaType.APPLICATION_JSON)
@FilterAbusiveMessages
public Response sendMultiDeviceMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent,
@HeaderParam("X-Forwarded-For") String forwardedFor,
@PathParam("destination") UUID destinationUuid,
@QueryParam("online") boolean online,
@QueryParam("ts") long timestamp,
@NotNull @Valid IncomingDeviceMessage[] messages)
throws RateLimitExceededException {
if (source.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
final String senderType;
if (source.isPresent()) {
if (source.get().getAccount().isIdentifiedBy(destinationUuid)) {
senderType = SENDER_TYPE_SELF;
} else {
senderType = SENDER_TYPE_IDENTIFIED;
}
} else {
senderType = SENDER_TYPE_UNIDENTIFIED;
}
for (final IncomingDeviceMessage message : messages) {
validateContentLength(message.getContent().length, userAgent);
validateEnvelopeType(message.getType(), userAgent);
}
try {
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid);
Optional<Account> destination;
if (!isSyncMessage) {
destination = accountsManager.getByAccountIdentifier(destinationUuid)
.or(() -> accountsManager.getByPhoneNumberIdentifier(destinationUuid));
} else {
destination = source.map(AuthenticatedAccount::getAccount);
}
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
assert (destination.isPresent());
if (source.isPresent() && !isSyncMessage) {
checkRateLimit(source.get(), destination.get(), userAgent);
}
final Set<Long> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
DestinationDeviceValidator.validateCompleteDeviceList(
destination.get(),
Arrays.stream(messages).map(IncomingDeviceMessage::getDeviceId).collect(Collectors.toSet()),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(
destination.get(),
Arrays.stream(messages).toList(),
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId,
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
Tag.of(SENDER_TYPE_TAG_NAME, senderType));
for (final IncomingDeviceMessage message : messages) {
Optional<Device> destinationDevice = destination.get().getDevice(message.getDeviceId());
if (destinationDevice.isPresent()) {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, timestamp, online, message);
}
}
return Response.ok(new SendMessageResponse(
!isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build();
} catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
}
@Timed @Timed
@Path("/multi_recipient") @Path("/multi_recipient")
@PUT @PUT
@ -658,34 +549,6 @@ public class MessageController {
} }
} }
private void sendMessage(Optional<AuthenticatedAccount> source, Account destinationAccount, Device destinationDevice,
UUID destinationUuid, long timestamp, boolean online, IncomingDeviceMessage message) throws NoSuchUserException {
try {
Envelope.Builder messageBuilder = Envelope.newBuilder();
long serverTimestamp = System.currentTimeMillis();
messageBuilder
.setType(Envelope.Type.forNumber(message.getType()))
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setContent(ByteString.copyFrom(message.getContent()));
source.ifPresent(authenticatedAccount ->
messageBuilder.setSource(authenticatedAccount.getAccount().getNumber())
.setSourceUuid(authenticatedAccount.getAccount().getUuid().toString())
.setSourceDevice((int) authenticatedAccount.getAuthenticatedDevice().getId()));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) {
throw new NoSuchUserException(e);
} else {
logger.debug("Not registered", e);
}
}
}
private void sendMessage(Account destinationAccount, private void sendMessage(Account destinationAccount,
Device destinationDevice, Device destinationDevice,
long timestamp, long timestamp,

View File

@ -1,47 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
public class IncomingDeviceMessage {
private final int type;
@Min(1)
private final long deviceId;
@Min(0)
@Max(65536)
private final int registrationId;
@NotNull
private final byte[] content;
public IncomingDeviceMessage(int type, long deviceId, int registrationId, byte[] content) {
this.type = type;
this.deviceId = deviceId;
this.registrationId = registrationId;
this.content = content;
}
public int getType() {
return type;
}
public long getDeviceId() {
return deviceId;
}
public int getRegistrationId() {
return registrationId;
}
public byte[] getContent() {
return content;
}
}

View File

@ -1,90 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.providers;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.io.InputStream;
import java.util.UUID;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.WebApplicationException;
public abstract class BinaryProviderBase {
/**
* Reads a UUID in network byte order and converts to a UUID object.
*/
UUID readUuid(InputStream stream) throws IOException {
byte[] buffer = new byte[8];
int read = stream.readNBytes(buffer, 0, 8);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long msb = convertNetworkByteOrderToLong(buffer);
read = stream.readNBytes(buffer, 0, 8);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long lsb = convertNetworkByteOrderToLong(buffer);
return new UUID(msb, lsb);
}
private long convertNetworkByteOrderToLong(byte[] buffer) {
long result = 0;
for (int i = 0; i < 8; i++) {
result = (result << 8) | (buffer[i] & 0xFFL);
}
return result;
}
/**
* Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
* {@code IOException} is thrown if the stream ends before we finish reading the varint.
*
* @return the varint value
*/
static long readVarint(InputStream stream) throws IOException, WebApplicationException {
boolean hasMore = true;
int currentOffset = 0;
long result = 0;
while (hasMore) {
if (currentOffset >= 64) {
throw new BadRequestException("varint is too large");
}
int b = stream.read();
if (b == -1) {
throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
}
if (currentOffset == 63 && (b & 0xFE) != 0) {
throw new BadRequestException("varint is too large");
}
hasMore = (b & 0x80) != 0;
result |= ((long)(b & 0x7F)) << currentOffset;
currentOffset += 7;
}
return result;
}
/**
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
* {@code [0, 65535]}.
*/
@VisibleForTesting
static int readU16(InputStream stream) throws IOException {
int b1 = stream.read();
if (b1 == -1) {
throw new IOException("Missing byte 1 of U16");
}
int b2 = stream.read();
if (b2 == -1) {
throw new IOException("Missing byte 2 of U16");
}
return (b1 << 8) | b2;
}
}

View File

@ -1,81 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.providers;
import io.dropwizard.util.DataSizeUnit;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.NoContentException;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.ext.MessageBodyReader;
import javax.ws.rs.ext.Provider;
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
@Provider
@Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE)
public class MultiDeviceMessageListProvider extends BinaryProviderBase implements MessageBodyReader<IncomingDeviceMessage[]> {
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mdml";
public static final int MAX_MESSAGE_COUNT = 50;
public static final int MAX_MESSAGE_SIZE = Math.toIntExact(DataSizeUnit.KIBIBYTES.toBytes(256));
public static final byte VERSION = 0x01;
@Override
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
return MEDIA_TYPE.equals(mediaType.toString()) && IncomingDeviceMessage[].class.isAssignableFrom(type);
}
@Override
public IncomingDeviceMessage[]
readFrom(Class<IncomingDeviceMessage[]> resultType, Type genericType,
Annotation[] annotations, MediaType mediaType, MultivaluedMap<String, String> httpHeaders,
InputStream entityStream)
throws IOException, WebApplicationException {
int versionByte = entityStream.read();
if (versionByte == -1) {
throw new NoContentException("Empty body not allowed");
}
if (versionByte != VERSION) {
throw new BadRequestException("Unsupported version");
}
int count = entityStream.read();
if (count == -1) {
throw new IOException("Missing count");
}
if (count > MAX_MESSAGE_COUNT) {
throw new BadRequestException("Maximum recipient count exceeded");
}
IncomingDeviceMessage[] messages = new IncomingDeviceMessage[count];
for (int i = 0; i < count; i++) {
long deviceId = readVarint(entityStream);
int registrationId = readU16(entityStream);
int type = entityStream.read();
if (type == -1) {
throw new IOException("Unexpected end of stream reading message type");
}
long messageLength = readVarint(entityStream);
if (messageLength > MAX_MESSAGE_SIZE) {
throw new WebApplicationException("Message body too large", Status.REQUEST_ENTITY_TOO_LARGE);
}
byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
if (contents.length != messageLength) {
throw new IOException("Unexpected end of stream in the middle of message contents");
}
messages[i] = new IncomingDeviceMessage(type, deviceId, registrationId, contents);
}
return messages;
}
}

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.providers; package org.whispersystems.textsecuregcm.providers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.util.DataSizeUnit; import io.dropwizard.util.DataSizeUnit;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -23,7 +24,7 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
@Provider @Provider
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
public class MultiRecipientMessageProvider extends BinaryProviderBase implements MessageBodyReader<MultiRecipientMessage> { public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRecipientMessage> {
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm"; public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
public static final int MAX_RECIPIENT_COUNT = 5000; public static final int MAX_RECIPIENT_COUNT = 5000;
@ -70,4 +71,78 @@ public class MultiRecipientMessageProvider extends BinaryProviderBase implements
} }
return new MultiRecipientMessage(recipients, commonPayload); return new MultiRecipientMessage(recipients, commonPayload);
} }
/**
* Reads a UUID in network byte order and converts to a UUID object.
*/
private UUID readUuid(InputStream stream) throws IOException {
byte[] buffer = new byte[8];
int read = stream.readNBytes(buffer, 0, 8);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long msb = convertNetworkByteOrderToLong(buffer);
read = stream.readNBytes(buffer, 0, 8);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long lsb = convertNetworkByteOrderToLong(buffer);
return new UUID(msb, lsb);
}
private long convertNetworkByteOrderToLong(byte[] buffer) {
long result = 0;
for (int i = 0; i < 8; i++) {
result = (result << 8) | (buffer[i] & 0xFFL);
}
return result;
}
/**
* Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
* {@code IOException} is thrown if the stream ends before we finish reading the varint.
*
* @return the varint value
*/
private long readVarint(InputStream stream) throws IOException, WebApplicationException {
boolean hasMore = true;
int currentOffset = 0;
int result = 0;
while (hasMore) {
if (currentOffset >= 64) {
throw new BadRequestException("varint is too large");
}
int b = stream.read();
if (b == -1) {
throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
}
if (currentOffset == 63 && (b & 0xFE) != 0) {
throw new BadRequestException("varint is too large");
}
hasMore = (b & 0x80) != 0;
result |= (b & 0x7F) << currentOffset;
currentOffset += 7;
}
return result;
}
/**
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
* {@code [0, 65535]}.
*/
@VisibleForTesting
static int readU16(InputStream stream) throws IOException {
int b1 = stream.read();
if (b1 == -1) {
throw new IOException("Missing byte 1 of U16");
}
int b2 = stream.read();
if (b2 == -1) {
throw new IOException("Missing byte 2 of U16");
}
return (b1 << 8) | b2;
}
} }

View File

@ -31,7 +31,6 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.ByteArrayOutputStream;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.LinkedList; import java.util.LinkedList;
@ -67,7 +66,6 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -116,7 +114,6 @@ class MessageControllerTest {
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class) .addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(MultiDeviceMessageListProvider.class)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
@ -176,72 +173,30 @@ class MessageControllerTest {
); );
} }
private static Stream<Entity<?>> currentMessageSingleDevicePayloads() { @Test
ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); void testSendFromDisabledAccount() throws Exception {
messageStream.write(1); // version
messageStream.write(1); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
private static Stream<Entity<?>> currentMessageSingleDevicePayloadsPni() {
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(1); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0x04, (byte)0x57 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource("currentMessageSingleDevicePayloads")
void testSendFromDisabledAccount(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Unauthorized response", response.getStatus(), is(equalTo(401))); assertThat("Unauthorized response", response.getStatus(), is(equalTo(401)));
} }
@ParameterizedTest @Test
@MethodSource("currentMessageSingleDevicePayloads") void testSingleDeviceCurrent() throws Exception {
void testSingleDeviceCurrent(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@ -252,15 +207,16 @@ class MessageControllerTest {
assertTrue(captor.getValue().hasSourceDevice()); assertTrue(captor.getValue().hasSourceDevice());
} }
@ParameterizedTest @Test
@MethodSource("currentMessageSingleDevicePayloadsPni") void testSingleDeviceCurrentByPni() throws Exception {
void testSingleDeviceCurrentByPni(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@ -284,15 +240,16 @@ class MessageControllerTest {
assertThat("Bad request", response.getStatus(), is(equalTo(422))); assertThat("Bad request", response.getStatus(), is(equalTo(422)));
} }
@ParameterizedTest @Test
@MethodSource("currentMessageSingleDevicePayloads") void testSingleDeviceCurrentUnidentified() throws Exception {
void testSingleDeviceCurrentUnidentified(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@ -303,27 +260,29 @@ class MessageControllerTest {
assertFalse(captor.getValue().hasSourceDevice()); assertFalse(captor.getValue().hasSourceDevice());
} }
@ParameterizedTest @Test
@MethodSource("currentMessageSingleDevicePayloads") void testSendBadAuth() throws Exception {
void testSendBadAuth(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response", response.getStatus(), is(equalTo(401))); assertThat("Good Response", response.getStatus(), is(equalTo(401)));
} }
@ParameterizedTest @Test
@MethodSource("currentMessageSingleDevicePayloads") void testMultiDeviceMissing() throws Exception {
void testMultiDeviceMissing(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
@ -334,15 +293,16 @@ class MessageControllerTest {
verifyNoMoreInteractions(messageSender); verifyNoMoreInteractions(messageSender);
} }
@ParameterizedTest @Test
@MethodSource void testMultiDeviceExtra() throws Exception {
void testMultiDeviceExtra(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_extra_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
@ -353,131 +313,47 @@ class MessageControllerTest {
verifyNoMoreInteractions(messageSender); verifyNoMoreInteractions(messageSender);
} }
private static Stream<Entity<?>> testMultiDeviceExtra() { @Test
ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); void testMultiDevice() throws Exception {
messageStream.write(1); // version
messageStream.write(2); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
messageStream.write(3); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_extra_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource
void testMultiDevice(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false));
} }
private static Stream<Entity<?>> testMultiDevice() { @Test
ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); void testMultiDeviceByPni() throws Exception {
messageStream.write(1); // version
messageStream.write(2); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
messageStream.write(2); // device ID
messageStream.writeBytes(new byte[] { (byte)1, (byte)77 }); // registration ID (333 = 1 * 256 + 77)
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource
void testMultiDeviceByPni(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false));
} }
private static Stream<Entity<?>> testMultiDeviceByPni() { @Test
ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); void testRegistrationIdMismatch() throws Exception {
messageStream.write(1); // version
messageStream.write(2); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0x08, (byte)0xae }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
messageStream.write(2); // device ID
messageStream.writeBytes(new byte[] { (byte)0x0d, (byte)0x05 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource
void testRegistrationIdMismatch(Entity<?> payload) throws Exception {
Response response = Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload); .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_registration_id.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response Code", response.getStatus(), is(equalTo(410))); assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
@ -488,35 +364,6 @@ class MessageControllerTest {
verifyNoMoreInteractions(messageSender); verifyNoMoreInteractions(messageSender);
} }
private static Stream<Entity<?>> testRegistrationIdMismatch() {
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(2); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
messageStream.write(2); // device ID
messageStream.writeBytes(new byte[] { (byte)0, (byte)77 }); // wrong registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_registration_id.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@Test @Test
void testGetMessages() { void testGetMessages() {
@ -726,15 +573,21 @@ class MessageControllerTest {
messageGuid, AuthHelper.VALID_UUID); messageGuid, AuthHelper.VALID_UUID);
} }
@ParameterizedTest @Test
@MethodSource void testValidateContentLength() throws Exception {
void testValidateContentLength(Entity<?> payload) throws Exception { final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
.put(payload); .put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, null, 1L, 1, new String(contentBytes))), false,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Bad response", response.getStatus(), is(equalTo(413))); assertThat("Bad response", response.getStatus(), is(equalTo(413)));
@ -742,44 +595,6 @@ class MessageControllerTest {
anyBoolean()); anyBoolean());
} }
private static Stream<Entity<?>> testValidateContentLength() {
final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(1); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[]{(byte) 0, (byte) 111}); // registration ID
messageStream.write(1); // message type
writeVarint(contentLength, messageStream); // message length
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
messageStream.writeBytes(contentBytes); // message contents
try {
return Stream.of(
Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, null, 1L, 1, new String(contentBytes))), false,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
private static void writeVarint(int value, ByteArrayOutputStream outputStream) {
while (true) {
int bits = value & 0x7f;
value >>>= 7;
if (value == 0) {
outputStream.write((byte) bits);
return;
}
outputStream.write((byte) (bits | 0x80));
}
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception { void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws Exception {

View File

@ -1,169 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.providers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.annotation.Annotation;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.NoContentException;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
public class MultiDeviceMessageListProviderTest {
static byte[] createByteArray(int... bytes) {
byte[] result = new byte[bytes.length];
for (int i = 0; i < bytes.length; i++) {
result[i] = (byte) bytes[i];
}
return result;
}
IncomingDeviceMessage[] tryRead(byte[] bytes) throws Exception {
MultiDeviceMessageListProvider provider = new MultiDeviceMessageListProvider();
return provider.readFrom(
IncomingDeviceMessage[].class,
IncomingDeviceMessage[].class,
new Annotation[] {},
MediaType.valueOf(MultiDeviceMessageListProvider.MEDIA_TYPE),
new MultivaluedHashMap<>(),
new ByteArrayInputStream(bytes));
}
@Test
void testInvalidVersion() {
assertThatThrownBy(() -> tryRead(createByteArray()))
.isInstanceOf(NoContentException.class);
assertThatThrownBy(() -> tryRead(createByteArray(0x00)))
.isInstanceOf(WebApplicationException.class);
assertThatThrownBy(() -> tryRead(createByteArray(0x59)))
.isInstanceOf(WebApplicationException.class);
}
@Test
void testBadCount() {
assertThatThrownBy(() -> tryRead(createByteArray(MultiDeviceMessageListProvider.VERSION)))
.isInstanceOf(IOException.class);
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION,
MultiDeviceMessageListProvider.MAX_MESSAGE_COUNT + 1)))
.isInstanceOf(WebApplicationException.class);
}
@Test void testBadDeviceId() {
// Missing
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1)))
.isInstanceOf(IOException.class);
// Unfinished varint
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80)))
.isInstanceOf(IOException.class);
}
@Test void testBadRegistrationId() {
// Missing
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1)))
.isInstanceOf(IOException.class);
// Truncated (fixed u16 value)
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11)))
.isInstanceOf(IOException.class);
}
@Test void testBadType() {
// Missing
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22)))
.isInstanceOf(IOException.class);
}
@Test void testBadMessageLength() {
// Missing
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1)))
.isInstanceOf(IOException.class);
// Unfinished varint
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80)))
.isInstanceOf(IOException.class);
// Too big
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80, 0x80, 0x80, 0x01)))
.isInstanceOf(WebApplicationException.class);
// Missing message
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x01)))
.isInstanceOf(IOException.class);
// Truncated message
assertThatThrownBy(() -> tryRead(createByteArray(
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 5, 0x01, 0x02)))
.isInstanceOf(IOException.class);
}
@Test
void readThreeMessages() throws Exception {
ByteArrayOutputStream contentStream = new ByteArrayOutputStream();
contentStream.write(MultiDeviceMessageListProvider.VERSION);
contentStream.write(3);
contentStream.writeBytes(createByteArray(0x85, 0x02)); // Device ID 0x0105
contentStream.writeBytes(createByteArray(0x11, 0x05)); // Registration ID 0x1105
contentStream.write(1);
contentStream.write(5);
contentStream.writeBytes(createByteArray(11, 22, 33, 44, 55));
contentStream.write(0x20); // Device ID 0x20
contentStream.writeBytes(createByteArray(0x30, 0x20)); // Registration ID 0x3020
contentStream.write(6);
contentStream.writeBytes(createByteArray(0x81, 0x01)); // 129 bytes
contentStream.writeBytes(new byte[129]);
contentStream.write(1); // Device ID 1
contentStream.writeBytes(createByteArray(0x00, 0x11)); // Registration ID 0x0011
contentStream.write(50);
contentStream.write(0); // empty message for some rateLimitReason
IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray());
assertThat(messages.length).isEqualTo(3);
assertThat(messages[0].getDeviceId()).isEqualTo(0x0105);
assertThat(messages[0].getRegistrationId()).isEqualTo(0x1105);
assertThat(messages[0].getType()).isEqualTo(1);
assertThat(messages[0].getContent()).containsExactly(11, 22, 33, 44, 55);
assertThat(messages[1].getDeviceId()).isEqualTo(0x20);
assertThat(messages[1].getRegistrationId()).isEqualTo(0x3020);
assertThat(messages[1].getType()).isEqualTo(6);
assertThat(messages[1].getContent()).containsExactly(new byte[129]);
assertThat(messages[2].getDeviceId()).isEqualTo(1);
assertThat(messages[2].getRegistrationId()).isEqualTo(0x0011);
assertThat(messages[2].getType()).isEqualTo(50);
assertThat(messages[2].getContent()).isEmpty();
}
@Test
void emptyListIsStillValid() throws Exception {
ByteArrayOutputStream contentStream = new ByteArrayOutputStream();
contentStream.write(MultiDeviceMessageListProvider.VERSION);
contentStream.write(0);
IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray());
assertThat(messages).isEmpty();
}
}

View File

@ -6,38 +6,28 @@
package org.whispersystems.textsecuregcm.providers; package org.whispersystems.textsecuregcm.providers;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.WebApplicationException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
public class MultiRecipientMessageProviderTest { public class MultiRecipientMessageProviderTest {
static byte[] createByteArray(int... bytes) { static byte[] createTwoByteArray(int b1, int b2) {
byte[] result = new byte[bytes.length]; return new byte[]{(byte) b1, (byte) b2};
for (int i = 0; i < bytes.length; i++) {
result[i] = (byte)bytes[i];
}
return result;
} }
static Stream<Arguments> readU16TestCases() { static Stream<Arguments> readU16TestCases() {
return Stream.of( return Stream.of(
arguments(0xFFFE, createByteArray(0xFF, 0xFE)), arguments(0xFFFE, createTwoByteArray(0xFF, 0xFE)),
arguments(0x0001, createByteArray(0x00, 0x01)), arguments(0x0001, createTwoByteArray(0x00, 0x01)),
arguments(0xBEEF, createByteArray(0xBE, 0xEF)), arguments(0xBEEF, createTwoByteArray(0xBE, 0xEF)),
arguments(0xFFFF, createByteArray(0xFF, 0xFF)), arguments(0xFFFF, createTwoByteArray(0xFF, 0xFF)),
arguments(0x0000, createByteArray(0x00, 0x00)), arguments(0x0000, createTwoByteArray(0x00, 0x00)),
arguments(0xF080, createByteArray(0xF0, 0x80)) arguments(0xF080, createTwoByteArray(0xF0, 0x80))
); );
} }
@ -48,42 +38,4 @@ public class MultiRecipientMessageProviderTest {
assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue); assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue);
} }
} }
static Stream<Arguments> readVarintTestCases() {
return Stream.of(
arguments(0x00L, createByteArray(0x00)),
arguments(0x01L, createByteArray(0x01)),
arguments(0x7FL, createByteArray(0x7F)),
arguments(0x80L, createByteArray(0x80, 0x01)),
arguments(0xFFL, createByteArray(0xFF, 0x01)),
arguments(0b1010101_0011001_1100110L, createByteArray(0b1_1100110, 0b1_0011001, 0b0_1010101)),
arguments(0x7FFFFFFF_FFFFFFFFL, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F)),
arguments(-1L, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01))
);
}
@ParameterizedTest
@MethodSource("readVarintTestCases")
void testReadVarint(long expectedValue, byte[] input) throws Exception {
try (final ByteArrayInputStream stream = new ByteArrayInputStream(input)) {
assertThat(MultiRecipientMessageProvider.readVarint(stream)).isEqualTo(expectedValue);
}
}
@Test
void testVarintEOF() {
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0xFF, 0x80))))
.isInstanceOf(IOException.class);
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray())))
.isInstanceOf(IOException.class);
}
@Test
void testVarintTooBig() {
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02))))
.isInstanceOf(WebApplicationException.class);
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80))))
.isInstanceOf(WebApplicationException.class);
}
} }