Add a binary format for incoming messages
The existing, general incoming message endpoint accepts messages as JSON strings containing base64 data, along with all the metadata as other JSON keys. That's not very efficient, and we don't make use of that full generality anyway. This commit introduces a new binary format that supports everything we're using from the old format (with the help of some query parameters like multi-recipient messages).
This commit is contained in:
parent
51bac394ec
commit
41bf2b2c42
|
@ -139,6 +139,7 @@ import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge;
|
|||
import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
|
||||
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
|
||||
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
||||
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
|
||||
|
@ -585,6 +586,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
|
||||
environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP));
|
||||
environment.jersey().register(acceptNumericOnlineFlagRequestFilter);
|
||||
environment.jersey().register(MultiDeviceMessageListProvider.class);
|
||||
environment.jersey().register(MultiRecipientMessageProvider.class);
|
||||
environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP));
|
||||
environment.jersey()
|
||||
|
@ -607,6 +609,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
webSocketEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
|
||||
webSocketEnvironment.jersey().register(new ContentLengthFilter(TrafficSource.WEBSOCKET));
|
||||
webSocketEnvironment.jersey().register(acceptNumericOnlineFlagRequestFilter);
|
||||
webSocketEnvironment.jersey().register(MultiDeviceMessageListProvider.class);
|
||||
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
|
||||
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
|
||||
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
|
||||
|
|
|
@ -62,6 +62,7 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey
|
|||
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
|
@ -77,6 +78,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
|||
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
|
@ -218,23 +220,15 @@ public class MessageController {
|
|||
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
|
||||
assert (destination.isPresent());
|
||||
|
||||
if (source.isPresent() && !source.get().getAccount().isIdentifiedBy(destinationUuid)) {
|
||||
final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber());
|
||||
|
||||
try {
|
||||
rateLimiters.getMessagesLimiter().validate(source.get().getAccount().getUuid(), destination.get().getUuid());
|
||||
} catch (final RateLimitExceededException e) {
|
||||
Metrics.counter(RATE_LIMITED_MESSAGE_COUNTER_NAME,
|
||||
SENDER_COUNTRY_TAG_NAME, senderCountryCode,
|
||||
RATE_LIMIT_REASON_TAG_NAME, "singleDestinationRate").increment();
|
||||
|
||||
throw e;
|
||||
}
|
||||
if (source.isPresent() && !isSyncMessage) {
|
||||
checkRateLimit(source.get(), destination.get());
|
||||
}
|
||||
|
||||
validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage,
|
||||
validateCompleteDeviceList(destination.get(), messages.getMessages(),
|
||||
IncomingMessage::getDestinationDeviceId, isSyncMessage,
|
||||
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
|
||||
validateRegistrationIds(destination.get(), messages.getMessages());
|
||||
validateRegistrationIds(destination.get(), messages.getMessages(),
|
||||
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
|
||||
|
||||
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
|
||||
|
@ -245,8 +239,108 @@ public class MessageController {
|
|||
|
||||
if (destinationDevice.isPresent()) {
|
||||
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
|
||||
sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.getTimestamp(),
|
||||
messages.isOnline(), incomingMessage, userAgent);
|
||||
sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.getTimestamp(), messages.isOnline(), incomingMessage, userAgent);
|
||||
}
|
||||
}
|
||||
|
||||
return Response.ok(new SendMessageResponse(
|
||||
!isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build();
|
||||
} catch (NoSuchUserException e) {
|
||||
throw new WebApplicationException(Response.status(404).build());
|
||||
} catch (MismatchedDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevices(e.getMissingDevices(),
|
||||
e.getExtraDevices()))
|
||||
.build());
|
||||
} catch (StaleDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevices(e.getStaleDevices()))
|
||||
.build());
|
||||
}
|
||||
}
|
||||
|
||||
@Timed
|
||||
@Path("/{destination}")
|
||||
@PUT
|
||||
@Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
@Produces(MediaType.APPLICATION_JSON)
|
||||
@FilterAbusiveMessages
|
||||
public Response sendMultiDeviceMessage(@Auth Optional<AuthenticatedAccount> source,
|
||||
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
|
||||
@HeaderParam("User-Agent") String userAgent,
|
||||
@HeaderParam("X-Forwarded-For") String forwardedFor,
|
||||
@PathParam("destination") UUID destinationUuid,
|
||||
@QueryParam("online") boolean online,
|
||||
@QueryParam("ts") long timestamp,
|
||||
@Valid IncomingDeviceMessage[] messages)
|
||||
throws RateLimitExceededException, RateLimitChallengeException {
|
||||
|
||||
if (source.isEmpty() && accessKey.isEmpty()) {
|
||||
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
|
||||
}
|
||||
|
||||
final String senderType;
|
||||
|
||||
if (source.isPresent()) {
|
||||
if (source.get().getAccount().isIdentifiedBy(destinationUuid)) {
|
||||
senderType = SENDER_TYPE_SELF;
|
||||
} else {
|
||||
senderType = SENDER_TYPE_IDENTIFIED;
|
||||
}
|
||||
} else {
|
||||
senderType = SENDER_TYPE_UNIDENTIFIED;
|
||||
}
|
||||
|
||||
for (final IncomingDeviceMessage message : messages) {
|
||||
int contentLength = message.getContent().length;
|
||||
|
||||
Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).record(contentLength);
|
||||
|
||||
if (contentLength > MAX_MESSAGE_SIZE) {
|
||||
Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).increment();
|
||||
return Response.status(Response.Status.REQUEST_ENTITY_TOO_LARGE).build();
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid);
|
||||
|
||||
Optional<Account> destination;
|
||||
|
||||
if (!isSyncMessage) {
|
||||
destination = accountsManager.getByAccountIdentifier(destinationUuid)
|
||||
.or(() -> accountsManager.getByPhoneNumberIdentifier(destinationUuid));
|
||||
} else {
|
||||
destination = source.map(AuthenticatedAccount::getAccount);
|
||||
}
|
||||
|
||||
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
|
||||
assert (destination.isPresent());
|
||||
|
||||
if (source.isPresent() && !isSyncMessage) {
|
||||
checkRateLimit(source.get(), destination.get());
|
||||
}
|
||||
|
||||
final List<IncomingDeviceMessage> messagesAsList = Arrays.asList(messages);
|
||||
validateCompleteDeviceList(destination.get(), messagesAsList,
|
||||
IncomingDeviceMessage::getDeviceId, isSyncMessage,
|
||||
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
|
||||
validateRegistrationIds(destination.get(), messagesAsList,
|
||||
IncomingDeviceMessage::getDeviceId,
|
||||
IncomingDeviceMessage::getRegistrationId);
|
||||
|
||||
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
|
||||
Tag.of(SENDER_TYPE_TAG_NAME, senderType));
|
||||
|
||||
for (final IncomingDeviceMessage message : messages) {
|
||||
Optional<Device> destinationDevice = destination.get().getDevice(message.getDeviceId());
|
||||
|
||||
if (destinationDevice.isPresent()) {
|
||||
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
|
||||
sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, timestamp, online, message, userAgent);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -544,6 +638,33 @@ public class MessageController {
|
|||
}
|
||||
}
|
||||
|
||||
private void sendMessage(Optional<AuthenticatedAccount> source, Account destinationAccount, Device destinationDevice, UUID destinationUuid, long timestamp, boolean online, IncomingDeviceMessage message, String userAgentString) throws NoSuchUserException {
|
||||
try {
|
||||
Envelope.Builder messageBuilder = Envelope.newBuilder();
|
||||
long serverTimestamp = System.currentTimeMillis();
|
||||
|
||||
messageBuilder
|
||||
.setType(Envelope.Type.forNumber(message.getType()))
|
||||
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
|
||||
.setServerTimestamp(serverTimestamp)
|
||||
.setDestinationUuid(destinationUuid.toString())
|
||||
.setContent(ByteString.copyFrom(message.getContent()));
|
||||
|
||||
source.ifPresent(authenticatedAccount ->
|
||||
messageBuilder.setSource(authenticatedAccount.getAccount().getNumber())
|
||||
.setSourceUuid(authenticatedAccount.getAccount().getUuid().toString())
|
||||
.setSourceDevice((int) authenticatedAccount.getAuthenticatedDevice().getId()));
|
||||
|
||||
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
|
||||
} catch (NotPushRegisteredException e) {
|
||||
if (destinationDevice.isMaster()) {
|
||||
throw new NoSuchUserException(e);
|
||||
} else {
|
||||
logger.debug("Not registered", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void sendMessage(Account destinationAccount,
|
||||
Device destinationDevice,
|
||||
long timestamp,
|
||||
|
@ -577,12 +698,26 @@ public class MessageController {
|
|||
}
|
||||
}
|
||||
|
||||
private void checkRateLimit(AuthenticatedAccount source, Account destination) throws RateLimitExceededException {
|
||||
final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());
|
||||
|
||||
try {
|
||||
rateLimiters.getMessagesLimiter().validate(source.getAccount().getUuid(), destination.getUuid());
|
||||
} catch (final RateLimitExceededException e) {
|
||||
Metrics.counter(RATE_LIMITED_MESSAGE_COUNTER_NAME,
|
||||
SENDER_COUNTRY_TAG_NAME, senderCountryCode,
|
||||
RATE_LIMIT_REASON_TAG_NAME, "singleDestinationRate").increment();
|
||||
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static void validateRegistrationIds(Account account, List<IncomingMessage> messages)
|
||||
public static <T> void validateRegistrationIds(Account account, List<T> messages, Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId)
|
||||
throws StaleDevicesException {
|
||||
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
|
||||
.stream()
|
||||
.map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId()));
|
||||
.map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message)));
|
||||
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
|
||||
}
|
||||
|
||||
|
@ -604,10 +739,10 @@ public class MessageController {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage,
|
||||
public static <T> void validateCompleteDeviceList(Account account, List<T> messages, Function<T, Long> getDeviceId, boolean isSyncMessage,
|
||||
Optional<Long> authenticatedDeviceId)
|
||||
throws MismatchedDevicesException {
|
||||
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId)
|
||||
Set<Long> messageDeviceIds = messages.stream().map(getDeviceId)
|
||||
.collect(Collectors.toSet());
|
||||
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import javax.validation.constraints.Max;
|
||||
import javax.validation.constraints.Min;
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
public class IncomingDeviceMessage {
|
||||
private final int type;
|
||||
|
||||
@Min(1)
|
||||
private final long deviceId;
|
||||
|
||||
@Min(0)
|
||||
@Max(65536)
|
||||
private final int registrationId;
|
||||
|
||||
@NotNull
|
||||
private final byte[] content;
|
||||
|
||||
public IncomingDeviceMessage(int type, long deviceId, int registrationId, byte[] content) {
|
||||
this.type = type;
|
||||
this.deviceId = deviceId;
|
||||
this.registrationId = registrationId;
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public int getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
public long getDeviceId() {
|
||||
return deviceId;
|
||||
}
|
||||
|
||||
public int getRegistrationId() {
|
||||
return registrationId;
|
||||
}
|
||||
|
||||
public byte[] getContent() {
|
||||
return content;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.providers;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.UUID;
|
||||
import javax.ws.rs.BadRequestException;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
|
||||
public abstract class BinaryProviderBase {
|
||||
|
||||
/**
|
||||
* Reads a UUID in network byte order and converts to a UUID object.
|
||||
*/
|
||||
UUID readUuid(InputStream stream) throws IOException {
|
||||
byte[] buffer = new byte[8];
|
||||
|
||||
int read = stream.readNBytes(buffer, 0, 8);
|
||||
if (read != 8) {
|
||||
throw new IOException("Insufficient bytes for UUID");
|
||||
}
|
||||
long msb = convertNetworkByteOrderToLong(buffer);
|
||||
|
||||
read = stream.readNBytes(buffer, 0, 8);
|
||||
if (read != 8) {
|
||||
throw new IOException("Insufficient bytes for UUID");
|
||||
}
|
||||
long lsb = convertNetworkByteOrderToLong(buffer);
|
||||
|
||||
return new UUID(msb, lsb);
|
||||
}
|
||||
|
||||
private long convertNetworkByteOrderToLong(byte[] buffer) {
|
||||
long result = 0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
result = (result << 8) | (buffer[i] & 0xFFL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
|
||||
* {@code IOException} is thrown if the stream ends before we finish reading the varint.
|
||||
*
|
||||
* @return the varint value
|
||||
*/
|
||||
static long readVarint(InputStream stream) throws IOException, WebApplicationException {
|
||||
boolean hasMore = true;
|
||||
int currentOffset = 0;
|
||||
long result = 0;
|
||||
while (hasMore) {
|
||||
if (currentOffset >= 64) {
|
||||
throw new BadRequestException("varint is too large");
|
||||
}
|
||||
int b = stream.read();
|
||||
if (b == -1) {
|
||||
throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
|
||||
}
|
||||
if (currentOffset == 63 && (b & 0xFE) != 0) {
|
||||
throw new BadRequestException("varint is too large");
|
||||
}
|
||||
hasMore = (b & 0x80) != 0;
|
||||
result |= ((long)(b & 0x7F)) << currentOffset;
|
||||
currentOffset += 7;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
|
||||
* {@code [0, 65535]}.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static int readU16(InputStream stream) throws IOException {
|
||||
int b1 = stream.read();
|
||||
if (b1 == -1) {
|
||||
throw new IOException("Missing byte 1 of U16");
|
||||
}
|
||||
int b2 = stream.read();
|
||||
if (b2 == -1) {
|
||||
throw new IOException("Missing byte 2 of U16");
|
||||
}
|
||||
return (b1 << 8) | b2;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.providers;
|
||||
|
||||
import io.dropwizard.util.DataSizeUnit;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.annotation.Annotation;
|
||||
import java.lang.reflect.Type;
|
||||
import javax.ws.rs.BadRequestException;
|
||||
import javax.ws.rs.Consumes;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.MultivaluedMap;
|
||||
import javax.ws.rs.core.NoContentException;
|
||||
import javax.ws.rs.ext.MessageBodyReader;
|
||||
import javax.ws.rs.ext.Provider;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
|
||||
|
||||
@Provider
|
||||
@Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
public class MultiDeviceMessageListProvider extends BinaryProviderBase implements MessageBodyReader<IncomingDeviceMessage[]> {
|
||||
|
||||
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mdml";
|
||||
public static final int MAX_MESSAGE_COUNT = 50;
|
||||
public static final int MAX_MESSAGE_SIZE = Math.toIntExact(DataSizeUnit.KIBIBYTES.toBytes(256));
|
||||
public static final byte VERSION = 0x01;
|
||||
|
||||
@Override
|
||||
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
|
||||
return MEDIA_TYPE.equals(mediaType.toString()) && IncomingDeviceMessage[].class.isAssignableFrom(type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public IncomingDeviceMessage[]
|
||||
readFrom(Class<IncomingDeviceMessage[]> resultType, Type genericType,
|
||||
Annotation[] annotations, MediaType mediaType, MultivaluedMap<String, String> httpHeaders,
|
||||
InputStream entityStream)
|
||||
throws IOException, WebApplicationException {
|
||||
int versionByte = entityStream.read();
|
||||
if (versionByte == -1) {
|
||||
throw new NoContentException("Empty body not allowed");
|
||||
}
|
||||
if (versionByte != VERSION) {
|
||||
throw new BadRequestException("Unsupported version");
|
||||
}
|
||||
int count = entityStream.read();
|
||||
if (count == -1) {
|
||||
throw new IOException("Missing count");
|
||||
}
|
||||
if (count > MAX_MESSAGE_COUNT) {
|
||||
throw new BadRequestException("Maximum recipient count exceeded");
|
||||
}
|
||||
IncomingDeviceMessage[] messages = new IncomingDeviceMessage[count];
|
||||
for (int i = 0; i < count; i++) {
|
||||
long deviceId = readVarint(entityStream);
|
||||
int registrationId = readU16(entityStream);
|
||||
|
||||
int type = entityStream.read();
|
||||
if (type == -1) {
|
||||
throw new IOException("Unexpected end of stream reading message type");
|
||||
}
|
||||
|
||||
long messageLength = readVarint(entityStream);
|
||||
if (messageLength > MAX_MESSAGE_SIZE) {
|
||||
throw new BadRequestException("Message body too large");
|
||||
}
|
||||
byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength));
|
||||
if (contents.length != messageLength) {
|
||||
throw new IOException("Unexpected end of stream in the middle of message contents");
|
||||
}
|
||||
|
||||
messages[i] = new IncomingDeviceMessage(type, deviceId, registrationId, contents);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
}
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.providers;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.dropwizard.util.DataSizeUnit;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
@ -24,7 +23,7 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
|||
|
||||
@Provider
|
||||
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
|
||||
public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRecipientMessage> {
|
||||
public class MultiRecipientMessageProvider extends BinaryProviderBase implements MessageBodyReader<MultiRecipientMessage> {
|
||||
|
||||
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
|
||||
public static final int MAX_RECIPIENT_COUNT = 5000;
|
||||
|
@ -71,78 +70,4 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
|
|||
}
|
||||
return new MultiRecipientMessage(recipients, commonPayload);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a UUID in network byte order and converts to a UUID object.
|
||||
*/
|
||||
private UUID readUuid(InputStream stream) throws IOException {
|
||||
byte[] buffer = new byte[8];
|
||||
|
||||
int read = stream.readNBytes(buffer, 0, 8);
|
||||
if (read != 8) {
|
||||
throw new IOException("Insufficient bytes for UUID");
|
||||
}
|
||||
long msb = convertNetworkByteOrderToLong(buffer);
|
||||
|
||||
read = stream.readNBytes(buffer, 0, 8);
|
||||
if (read != 8) {
|
||||
throw new IOException("Insufficient bytes for UUID");
|
||||
}
|
||||
long lsb = convertNetworkByteOrderToLong(buffer);
|
||||
|
||||
return new UUID(msb, lsb);
|
||||
}
|
||||
|
||||
private long convertNetworkByteOrderToLong(byte[] buffer) {
|
||||
long result = 0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
result = (result << 8) | (buffer[i] & 0xFFL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
|
||||
* {@code IOException} is thrown if the stream ends before we finish reading the varint.
|
||||
*
|
||||
* @return the varint value
|
||||
*/
|
||||
private long readVarint(InputStream stream) throws IOException, WebApplicationException {
|
||||
boolean hasMore = true;
|
||||
int currentOffset = 0;
|
||||
int result = 0;
|
||||
while (hasMore) {
|
||||
if (currentOffset >= 64) {
|
||||
throw new BadRequestException("varint is too large");
|
||||
}
|
||||
int b = stream.read();
|
||||
if (b == -1) {
|
||||
throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
|
||||
}
|
||||
if (currentOffset == 63 && (b & 0xFE) != 0) {
|
||||
throw new BadRequestException("varint is too large");
|
||||
}
|
||||
hasMore = (b & 0x80) != 0;
|
||||
result |= (b & 0x7F) << currentOffset;
|
||||
currentOffset += 7;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
|
||||
* {@code [0, 65535]}.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static int readU16(InputStream stream) throws IOException {
|
||||
int b1 = stream.read();
|
||||
if (b1 == -1) {
|
||||
throw new IOException("Missing byte 1 of U16");
|
||||
}
|
||||
int b2 = stream.read();
|
||||
if (b2 == -1) {
|
||||
throw new IOException("Missing byte 2 of U16");
|
||||
}
|
||||
return (b1 << 8) | b2;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
|||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
|
@ -72,6 +73,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
|||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider;
|
||||
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
||||
|
@ -111,13 +113,14 @@ class MessageControllerTest {
|
|||
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
|
||||
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
|
||||
|
||||
private final ObjectMapper mapper = new ObjectMapper();
|
||||
private static final ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||
.addProvider(AuthHelper.getAuthFilter())
|
||||
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
|
||||
.addProvider(RateLimitExceededExceptionMapper.class)
|
||||
.addProvider(MultiDeviceMessageListProvider.class)
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
|
||||
messagesManager, apnFallbackManager, reportMessageManager, multiRecipientMessageExecutor))
|
||||
|
@ -174,28 +177,50 @@ class MessageControllerTest {
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSendFromDisabledAccount() throws Exception {
|
||||
private static Stream<Entity<?>> currentMessageSingleDevicePayloads() {
|
||||
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
|
||||
messageStream.write(1); // version
|
||||
messageStream.write(1); // count
|
||||
messageStream.write(1); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
try {
|
||||
return Stream.of(
|
||||
Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"),
|
||||
IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE),
|
||||
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testSendFromDisabledAccount(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Unauthorized response", response.getStatus(), is(equalTo(401)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSingleDeviceCurrent() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testSingleDeviceCurrent(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
|
@ -239,15 +264,15 @@ class MessageControllerTest {
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSingleDeviceCurrentByPni() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testSingleDeviceCurrentByPni(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
|
@ -271,15 +296,15 @@ class MessageControllerTest {
|
|||
assertThat("Bad request", response.getStatus(), is(equalTo(422)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSingleDeviceCurrentUnidentified() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testSingleDeviceCurrentUnidentified(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
|
@ -290,28 +315,27 @@ class MessageControllerTest {
|
|||
assertFalse(captor.getValue().hasSourceDevice());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void testSendBadAuth() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testSendBadAuth(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
|
||||
.request()
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(401)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiDeviceMissing() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("currentMessageSingleDevicePayloads")
|
||||
void testMultiDeviceMissing(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
|
@ -322,15 +346,15 @@ class MessageControllerTest {
|
|||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiDeviceExtra() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testMultiDeviceExtra(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_extra_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
|
@ -341,29 +365,87 @@ class MessageControllerTest {
|
|||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiDevice() throws Exception {
|
||||
private static Stream<Entity<?>> testMultiDeviceExtra() {
|
||||
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
|
||||
messageStream.write(1); // version
|
||||
messageStream.write(2); // count
|
||||
|
||||
messageStream.write(1); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
messageStream.write(3); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
try {
|
||||
return Stream.of(
|
||||
Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_extra_device.json"),
|
||||
IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE),
|
||||
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testMultiDevice(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_multi_device.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRegistrationIdMismatch() throws Exception {
|
||||
private static Stream<Entity<?>> testMultiDevice() {
|
||||
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
|
||||
messageStream.write(1); // version
|
||||
messageStream.write(2); // count
|
||||
|
||||
messageStream.write(1); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
messageStream.write(2); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)77 }); // registration ID (333 = 1 * 256 + 77)
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
try {
|
||||
return Stream.of(
|
||||
Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_multi_device.json"),
|
||||
IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE),
|
||||
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testRegistrationIdMismatch(Entity<?> payload) throws Exception {
|
||||
Response response =
|
||||
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_registration_id.json"), IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE));
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(payload);
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
|
||||
|
||||
|
@ -372,7 +454,35 @@ class MessageControllerTest {
|
|||
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
|
||||
private static Stream<Entity<?>> testRegistrationIdMismatch() {
|
||||
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
|
||||
messageStream.write(1); // version
|
||||
messageStream.write(2); // count
|
||||
|
||||
messageStream.write(1); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
messageStream.write(2); // device ID
|
||||
messageStream.writeBytes(new byte[] { (byte)0, (byte)77 }); // wrong registration ID
|
||||
messageStream.write(1); // message type
|
||||
messageStream.write(3); // message length
|
||||
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
|
||||
|
||||
try {
|
||||
return Stream.of(
|
||||
Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_registration_id.json"),
|
||||
IncomingMessageList.class),
|
||||
MediaType.APPLICATION_JSON_TYPE),
|
||||
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -395,11 +505,10 @@ class MessageControllerTest {
|
|||
|
||||
OutgoingMessageEntityList response =
|
||||
resources.getJerseyTest().target("/v1/messages/")
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.accept(MediaType.APPLICATION_JSON_TYPE)
|
||||
.get(OutgoingMessageEntityList.class);
|
||||
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.accept(MediaType.APPLICATION_JSON_TYPE)
|
||||
.get(OutgoingMessageEntityList.class);
|
||||
|
||||
assertEquals(response.getMessages().size(), 2);
|
||||
|
||||
|
@ -429,10 +538,10 @@ class MessageControllerTest {
|
|||
|
||||
Response response =
|
||||
resources.getJerseyTest().target("/v1/messages/")
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
|
||||
.accept(MediaType.APPLICATION_JSON_TYPE)
|
||||
.get();
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
|
||||
.accept(MediaType.APPLICATION_JSON_TYPE)
|
||||
.get();
|
||||
|
||||
assertThat("Unauthorized response", response.getStatus(), is(equalTo(401)));
|
||||
}
|
||||
|
@ -453,33 +562,32 @@ class MessageControllerTest {
|
|||
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
|
||||
null, System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)));
|
||||
|
||||
|
||||
UUID uuid3 = UUID.randomUUID();
|
||||
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3)).thenReturn(Optional.empty());
|
||||
|
||||
Response response = resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid1))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid1))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
|
||||
verify(receiptSender).sendReceipt(any(AuthenticatedAccount.class), eq(sourceUuid), eq(timestamp));
|
||||
|
||||
response = resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid2))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid2))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
|
||||
verifyNoMoreInteractions(receiptSender);
|
||||
|
||||
response = resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid3))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
.target(String.format("/v1/messages/uuid/%s", uuid3))
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.delete();
|
||||
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
|
||||
verifyNoMoreInteractions(receiptSender);
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.providers;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.lang.annotation.Annotation;
|
||||
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.MultivaluedHashMap;
|
||||
import javax.ws.rs.core.NoContentException;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage;
|
||||
|
||||
public class MultiDeviceMessageListProviderTest {
|
||||
|
||||
static byte[] createByteArray(int... bytes) {
|
||||
byte[] result = new byte[bytes.length];
|
||||
for (int i = 0; i < bytes.length; i++) {
|
||||
result[i] = (byte) bytes[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
IncomingDeviceMessage[] tryRead(byte[] bytes) throws Exception {
|
||||
MultiDeviceMessageListProvider provider = new MultiDeviceMessageListProvider();
|
||||
return provider.readFrom(
|
||||
IncomingDeviceMessage[].class,
|
||||
IncomingDeviceMessage[].class,
|
||||
new Annotation[] {},
|
||||
MediaType.valueOf(MultiDeviceMessageListProvider.MEDIA_TYPE),
|
||||
new MultivaluedHashMap<>(),
|
||||
new ByteArrayInputStream(bytes));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInvalidVersion() {
|
||||
assertThatThrownBy(() -> tryRead(createByteArray()))
|
||||
.isInstanceOf(NoContentException.class);
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(0x00)))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(0x59)))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBadCount() {
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(MultiDeviceMessageListProvider.VERSION)))
|
||||
.isInstanceOf(IOException.class);
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION,
|
||||
MultiDeviceMessageListProvider.MAX_MESSAGE_COUNT + 1)))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
}
|
||||
|
||||
@Test void testBadDeviceId() {
|
||||
// Missing
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1)))
|
||||
.isInstanceOf(IOException.class);
|
||||
// Unfinished varint
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80)))
|
||||
.isInstanceOf(IOException.class);
|
||||
}
|
||||
|
||||
@Test void testBadRegistrationId() {
|
||||
// Missing
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1)))
|
||||
.isInstanceOf(IOException.class);
|
||||
// Truncated (fixed u16 value)
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11)))
|
||||
.isInstanceOf(IOException.class);
|
||||
}
|
||||
|
||||
@Test void testBadType() {
|
||||
// Missing
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22)))
|
||||
.isInstanceOf(IOException.class);
|
||||
}
|
||||
|
||||
@Test void testBadMessageLength() {
|
||||
// Missing
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1)))
|
||||
.isInstanceOf(IOException.class);
|
||||
// Unfinished varint
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80)))
|
||||
.isInstanceOf(IOException.class);
|
||||
// Too big
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80, 0x80, 0x80, 0x01)))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
// Missing message
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x01)))
|
||||
.isInstanceOf(IOException.class);
|
||||
// Truncated message
|
||||
assertThatThrownBy(() -> tryRead(createByteArray(
|
||||
MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 5, 0x01, 0x02)))
|
||||
.isInstanceOf(IOException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void readThreeMessages() throws Exception {
|
||||
ByteArrayOutputStream contentStream = new ByteArrayOutputStream();
|
||||
contentStream.write(MultiDeviceMessageListProvider.VERSION);
|
||||
contentStream.write(3);
|
||||
|
||||
contentStream.writeBytes(createByteArray(0x85, 0x02)); // Device ID 0x0105
|
||||
contentStream.writeBytes(createByteArray(0x11, 0x05)); // Registration ID 0x1105
|
||||
contentStream.write(1);
|
||||
contentStream.write(5);
|
||||
contentStream.writeBytes(createByteArray(11, 22, 33, 44, 55));
|
||||
|
||||
contentStream.write(0x20); // Device ID 0x20
|
||||
contentStream.writeBytes(createByteArray(0x30, 0x20)); // Registration ID 0x3020
|
||||
contentStream.write(6);
|
||||
contentStream.writeBytes(createByteArray(0x81, 0x01)); // 129 bytes
|
||||
contentStream.writeBytes(new byte[129]);
|
||||
|
||||
contentStream.write(1); // Device ID 1
|
||||
contentStream.writeBytes(createByteArray(0x00, 0x11)); // Registration ID 0x0011
|
||||
contentStream.write(50);
|
||||
contentStream.write(0); // empty message for some rateLimitReason
|
||||
|
||||
IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray());
|
||||
assertThat(messages.length).isEqualTo(3);
|
||||
|
||||
assertThat(messages[0].getDeviceId()).isEqualTo(0x0105);
|
||||
assertThat(messages[0].getRegistrationId()).isEqualTo(0x1105);
|
||||
assertThat(messages[0].getType()).isEqualTo(1);
|
||||
assertThat(messages[0].getContent()).containsExactly(11, 22, 33, 44, 55);
|
||||
|
||||
assertThat(messages[1].getDeviceId()).isEqualTo(0x20);
|
||||
assertThat(messages[1].getRegistrationId()).isEqualTo(0x3020);
|
||||
assertThat(messages[1].getType()).isEqualTo(6);
|
||||
assertThat(messages[1].getContent()).containsExactly(new byte[129]);
|
||||
|
||||
assertThat(messages[2].getDeviceId()).isEqualTo(1);
|
||||
assertThat(messages[2].getRegistrationId()).isEqualTo(0x0011);
|
||||
assertThat(messages[2].getType()).isEqualTo(50);
|
||||
assertThat(messages[2].getContent()).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void emptyListIsStillValid() throws Exception {
|
||||
ByteArrayOutputStream contentStream = new ByteArrayOutputStream();
|
||||
contentStream.write(MultiDeviceMessageListProvider.VERSION);
|
||||
contentStream.write(0);
|
||||
|
||||
IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray());
|
||||
assertThat(messages).isEmpty();
|
||||
}
|
||||
|
||||
}
|
|
@ -6,28 +6,38 @@
|
|||
package org.whispersystems.textsecuregcm.providers;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
public class MultiRecipientMessageProviderTest {
|
||||
|
||||
static byte[] createTwoByteArray(int b1, int b2) {
|
||||
return new byte[]{(byte) b1, (byte) b2};
|
||||
static byte[] createByteArray(int... bytes) {
|
||||
byte[] result = new byte[bytes.length];
|
||||
for (int i = 0; i < bytes.length; i++) {
|
||||
result[i] = (byte)bytes[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static Stream<Arguments> readU16TestCases() {
|
||||
return Stream.of(
|
||||
arguments(0xFFFE, createTwoByteArray(0xFF, 0xFE)),
|
||||
arguments(0x0001, createTwoByteArray(0x00, 0x01)),
|
||||
arguments(0xBEEF, createTwoByteArray(0xBE, 0xEF)),
|
||||
arguments(0xFFFF, createTwoByteArray(0xFF, 0xFF)),
|
||||
arguments(0x0000, createTwoByteArray(0x00, 0x00)),
|
||||
arguments(0xF080, createTwoByteArray(0xF0, 0x80))
|
||||
arguments(0xFFFE, createByteArray(0xFF, 0xFE)),
|
||||
arguments(0x0001, createByteArray(0x00, 0x01)),
|
||||
arguments(0xBEEF, createByteArray(0xBE, 0xEF)),
|
||||
arguments(0xFFFF, createByteArray(0xFF, 0xFF)),
|
||||
arguments(0x0000, createByteArray(0x00, 0x00)),
|
||||
arguments(0xF080, createByteArray(0xF0, 0x80))
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -38,7 +48,42 @@ public class MultiRecipientMessageProviderTest {
|
|||
assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static Stream<Arguments> readVarintTestCases() {
|
||||
return Stream.of(
|
||||
arguments(0x00L, createByteArray(0x00)),
|
||||
arguments(0x01L, createByteArray(0x01)),
|
||||
arguments(0x7FL, createByteArray(0x7F)),
|
||||
arguments(0x80L, createByteArray(0x80, 0x01)),
|
||||
arguments(0xFFL, createByteArray(0xFF, 0x01)),
|
||||
arguments(0b1010101_0011001_1100110L, createByteArray(0b1_1100110, 0b1_0011001, 0b0_1010101)),
|
||||
arguments(0x7FFFFFFF_FFFFFFFFL, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F)),
|
||||
arguments(-1L, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01))
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("readVarintTestCases")
|
||||
void testReadVarint(long expectedValue, byte[] input) throws Exception {
|
||||
try (final ByteArrayInputStream stream = new ByteArrayInputStream(input)) {
|
||||
assertThat(MultiRecipientMessageProvider.readVarint(stream)).isEqualTo(expectedValue);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVarintEOF() {
|
||||
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0xFF, 0x80))))
|
||||
.isInstanceOf(IOException.class);
|
||||
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray())))
|
||||
.isInstanceOf(IOException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVarintTooBig() {
|
||||
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02))))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80))))
|
||||
.isInstanceOf(WebApplicationException.class);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue