diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 58fe8e599..048587fbd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -139,6 +139,7 @@ import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge; import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.TrafficSource; +import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; @@ -585,6 +586,7 @@ public class WhisperServerService extends Application tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())), @@ -245,8 +239,108 @@ public class MessageController { if (destinationDevice.isPresent()) { Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); - sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.getTimestamp(), - messages.isOnline(), incomingMessage, userAgent); + sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.getTimestamp(), messages.isOnline(), incomingMessage, userAgent); + } + } + + return Response.ok(new SendMessageResponse( + !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build(); + } catch (NoSuchUserException e) { + throw new WebApplicationException(Response.status(404).build()); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), + e.getExtraDevices())) + .build()); + } catch (StaleDevicesException e) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevices(e.getStaleDevices())) + .build()); + } + } + + @Timed + @Path("/{destination}") + @PUT + @Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE) + @Produces(MediaType.APPLICATION_JSON) + @FilterAbusiveMessages + public Response sendMultiDeviceMessage(@Auth Optional source, + @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam("User-Agent") String userAgent, + @HeaderParam("X-Forwarded-For") String forwardedFor, + @PathParam("destination") UUID destinationUuid, + @QueryParam("online") boolean online, + @QueryParam("ts") long timestamp, + @Valid IncomingDeviceMessage[] messages) + throws RateLimitExceededException, RateLimitChallengeException { + + if (source.isEmpty() && accessKey.isEmpty()) { + throw new WebApplicationException(Response.Status.UNAUTHORIZED); + } + + final String senderType; + + if (source.isPresent()) { + if (source.get().getAccount().isIdentifiedBy(destinationUuid)) { + senderType = SENDER_TYPE_SELF; + } else { + senderType = SENDER_TYPE_IDENTIFIED; + } + } else { + senderType = SENDER_TYPE_UNIDENTIFIED; + } + + for (final IncomingDeviceMessage message : messages) { + int contentLength = message.getContent().length; + + Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).record(contentLength); + + if (contentLength > MAX_MESSAGE_SIZE) { + Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).increment(); + return Response.status(Response.Status.REQUEST_ENTITY_TOO_LARGE).build(); + } + } + + try { + boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid); + + Optional destination; + + if (!isSyncMessage) { + destination = accountsManager.getByAccountIdentifier(destinationUuid) + .or(() -> accountsManager.getByPhoneNumberIdentifier(destinationUuid)); + } else { + destination = source.map(AuthenticatedAccount::getAccount); + } + + OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); + assert (destination.isPresent()); + + if (source.isPresent() && !isSyncMessage) { + checkRateLimit(source.get(), destination.get()); + } + + final List messagesAsList = Arrays.asList(messages); + validateCompleteDeviceList(destination.get(), messagesAsList, + IncomingDeviceMessage::getDeviceId, isSyncMessage, + source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); + validateRegistrationIds(destination.get(), messagesAsList, + IncomingDeviceMessage::getDeviceId, + IncomingDeviceMessage::getRegistrationId); + + final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), + Tag.of(SENDER_TYPE_TAG_NAME, senderType)); + + for (final IncomingDeviceMessage message : messages) { + Optional destinationDevice = destination.get().getDevice(message.getDeviceId()); + + if (destinationDevice.isPresent()) { + Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); + sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, timestamp, online, message, userAgent); } } @@ -544,6 +638,33 @@ public class MessageController { } } + private void sendMessage(Optional source, Account destinationAccount, Device destinationDevice, UUID destinationUuid, long timestamp, boolean online, IncomingDeviceMessage message, String userAgentString) throws NoSuchUserException { + try { + Envelope.Builder messageBuilder = Envelope.newBuilder(); + long serverTimestamp = System.currentTimeMillis(); + + messageBuilder + .setType(Envelope.Type.forNumber(message.getType())) + .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) + .setServerTimestamp(serverTimestamp) + .setDestinationUuid(destinationUuid.toString()) + .setContent(ByteString.copyFrom(message.getContent())); + + source.ifPresent(authenticatedAccount -> + messageBuilder.setSource(authenticatedAccount.getAccount().getNumber()) + .setSourceUuid(authenticatedAccount.getAccount().getUuid().toString()) + .setSourceDevice((int) authenticatedAccount.getAuthenticatedDevice().getId())); + + messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); + } catch (NotPushRegisteredException e) { + if (destinationDevice.isMaster()) { + throw new NoSuchUserException(e); + } else { + logger.debug("Not registered", e); + } + } + } + private void sendMessage(Account destinationAccount, Device destinationDevice, long timestamp, @@ -577,12 +698,26 @@ public class MessageController { } } + private void checkRateLimit(AuthenticatedAccount source, Account destination) throws RateLimitExceededException { + final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); + + try { + rateLimiters.getMessagesLimiter().validate(source.getAccount().getUuid(), destination.getUuid()); + } catch (final RateLimitExceededException e) { + Metrics.counter(RATE_LIMITED_MESSAGE_COUNTER_NAME, + SENDER_COUNTRY_TAG_NAME, senderCountryCode, + RATE_LIMIT_REASON_TAG_NAME, "singleDestinationRate").increment(); + + throw e; + } + } + @VisibleForTesting - public static void validateRegistrationIds(Account account, List messages) + public static void validateRegistrationIds(Account account, List messages, Function getDeviceId, Function getRegistrationId) throws StaleDevicesException { final Stream> deviceIdAndRegistrationIdStream = messages .stream() - .map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId())); + .map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message))); validateRegistrationIds(account, deviceIdAndRegistrationIdStream); } @@ -604,10 +739,10 @@ public class MessageController { } @VisibleForTesting - public static void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage, + public static void validateCompleteDeviceList(Account account, List messages, Function getDeviceId, boolean isSyncMessage, Optional authenticatedDeviceId) throws MismatchedDevicesException { - Set messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId) + Set messageDeviceIds = messages.stream().map(getDeviceId) .collect(Collectors.toSet()); validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingDeviceMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingDeviceMessage.java new file mode 100644 index 000000000..2a0c887e8 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingDeviceMessage.java @@ -0,0 +1,47 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import javax.validation.constraints.Max; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +public class IncomingDeviceMessage { + private final int type; + + @Min(1) + private final long deviceId; + + @Min(0) + @Max(65536) + private final int registrationId; + + @NotNull + private final byte[] content; + + public IncomingDeviceMessage(int type, long deviceId, int registrationId, byte[] content) { + this.type = type; + this.deviceId = deviceId; + this.registrationId = registrationId; + this.content = content; + } + + public int getType() { + return type; + } + + public long getDeviceId() { + return deviceId; + } + + public int getRegistrationId() { + return registrationId; + } + + public byte[] getContent() { + return content; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/BinaryProviderBase.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/BinaryProviderBase.java new file mode 100644 index 000000000..1887cb53b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/BinaryProviderBase.java @@ -0,0 +1,90 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.providers; + +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.io.InputStream; +import java.util.UUID; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.WebApplicationException; + +public abstract class BinaryProviderBase { + + /** + * Reads a UUID in network byte order and converts to a UUID object. + */ + UUID readUuid(InputStream stream) throws IOException { + byte[] buffer = new byte[8]; + + int read = stream.readNBytes(buffer, 0, 8); + if (read != 8) { + throw new IOException("Insufficient bytes for UUID"); + } + long msb = convertNetworkByteOrderToLong(buffer); + + read = stream.readNBytes(buffer, 0, 8); + if (read != 8) { + throw new IOException("Insufficient bytes for UUID"); + } + long lsb = convertNetworkByteOrderToLong(buffer); + + return new UUID(msb, lsb); + } + + private long convertNetworkByteOrderToLong(byte[] buffer) { + long result = 0; + for (int i = 0; i < 8; i++) { + result = (result << 8) | (buffer[i] & 0xFFL); + } + return result; + } + + /** + * Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An + * {@code IOException} is thrown if the stream ends before we finish reading the varint. + * + * @return the varint value + */ + static long readVarint(InputStream stream) throws IOException, WebApplicationException { + boolean hasMore = true; + int currentOffset = 0; + long result = 0; + while (hasMore) { + if (currentOffset >= 64) { + throw new BadRequestException("varint is too large"); + } + int b = stream.read(); + if (b == -1) { + throw new IOException("Missing byte " + (currentOffset / 7) + " of varint"); + } + if (currentOffset == 63 && (b & 0xFE) != 0) { + throw new BadRequestException("varint is too large"); + } + hasMore = (b & 0x80) != 0; + result |= ((long)(b & 0x7F)) << currentOffset; + currentOffset += 7; + } + return result; + } + + /** + * Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is + * {@code [0, 65535]}. + */ + @VisibleForTesting + static int readU16(InputStream stream) throws IOException { + int b1 = stream.read(); + if (b1 == -1) { + throw new IOException("Missing byte 1 of U16"); + } + int b2 = stream.read(); + if (b2 == -1) { + throw new IOException("Missing byte 2 of U16"); + } + return (b1 << 8) | b2; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java new file mode 100644 index 000000000..e2ef4dce4 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProvider.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.providers; + +import io.dropwizard.util.DataSizeUnit; +import java.io.IOException; +import java.io.InputStream; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.NoContentException; +import javax.ws.rs.ext.MessageBodyReader; +import javax.ws.rs.ext.Provider; +import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage; + +@Provider +@Consumes(MultiDeviceMessageListProvider.MEDIA_TYPE) +public class MultiDeviceMessageListProvider extends BinaryProviderBase implements MessageBodyReader { + + public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mdml"; + public static final int MAX_MESSAGE_COUNT = 50; + public static final int MAX_MESSAGE_SIZE = Math.toIntExact(DataSizeUnit.KIBIBYTES.toBytes(256)); + public static final byte VERSION = 0x01; + + @Override + public boolean isReadable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { + return MEDIA_TYPE.equals(mediaType.toString()) && IncomingDeviceMessage[].class.isAssignableFrom(type); + } + + @Override + public IncomingDeviceMessage[] + readFrom(Class resultType, Type genericType, + Annotation[] annotations, MediaType mediaType, MultivaluedMap httpHeaders, + InputStream entityStream) + throws IOException, WebApplicationException { + int versionByte = entityStream.read(); + if (versionByte == -1) { + throw new NoContentException("Empty body not allowed"); + } + if (versionByte != VERSION) { + throw new BadRequestException("Unsupported version"); + } + int count = entityStream.read(); + if (count == -1) { + throw new IOException("Missing count"); + } + if (count > MAX_MESSAGE_COUNT) { + throw new BadRequestException("Maximum recipient count exceeded"); + } + IncomingDeviceMessage[] messages = new IncomingDeviceMessage[count]; + for (int i = 0; i < count; i++) { + long deviceId = readVarint(entityStream); + int registrationId = readU16(entityStream); + + int type = entityStream.read(); + if (type == -1) { + throw new IOException("Unexpected end of stream reading message type"); + } + + long messageLength = readVarint(entityStream); + if (messageLength > MAX_MESSAGE_SIZE) { + throw new BadRequestException("Message body too large"); + } + byte[] contents = entityStream.readNBytes(Math.toIntExact(messageLength)); + if (contents.length != messageLength) { + throw new IOException("Unexpected end of stream in the middle of message contents"); + } + + messages[i] = new IncomingDeviceMessage(type, deviceId, registrationId, contents); + } + return messages; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index a19cf087b..d81f4bb7c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.providers; -import com.google.common.annotations.VisibleForTesting; import io.dropwizard.util.DataSizeUnit; import java.io.IOException; import java.io.InputStream; @@ -24,7 +23,7 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; @Provider @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) -public class MultiRecipientMessageProvider implements MessageBodyReader { +public class MultiRecipientMessageProvider extends BinaryProviderBase implements MessageBodyReader { public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm"; public static final int MAX_RECIPIENT_COUNT = 5000; @@ -71,78 +70,4 @@ public class MultiRecipientMessageProvider implements MessageBodyReader= 64) { - throw new BadRequestException("varint is too large"); - } - int b = stream.read(); - if (b == -1) { - throw new IOException("Missing byte " + (currentOffset / 7) + " of varint"); - } - if (currentOffset == 63 && (b & 0xFE) != 0) { - throw new BadRequestException("varint is too large"); - } - hasMore = (b & 0x80) != 0; - result |= (b & 0x7F) << currentOffset; - currentOffset += 7; - } - return result; - } - - /** - * Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is - * {@code [0, 65535]}. - */ - @VisibleForTesting - static int readU16(InputStream stream) throws IOException { - int b1 = stream.read(); - if (b1 == -1) { - throw new IOException("Missing byte 1 of U16"); - } - int b2 = stream.read(); - if (b2 == -1) { - throw new IOException("Missing byte 2 of U16"); - } - return (b1 << 8) | b2; - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 7c1a6a54a..c92a8280d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -35,6 +35,7 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import java.io.ByteArrayOutputStream; import java.util.Base64; import java.util.Collection; import java.util.HashSet; @@ -72,6 +73,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.providers.MultiDeviceMessageListProvider; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -111,13 +113,14 @@ class MessageControllerTest { private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); - private final ObjectMapper mapper = new ObjectMapper(); + private static final ObjectMapper mapper = new ObjectMapper(); private static final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .addProvider(RateLimitExceededExceptionMapper.class) + .addProvider(MultiDeviceMessageListProvider.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, apnFallbackManager, reportMessageManager, multiRecipientMessageExecutor)) @@ -174,28 +177,50 @@ class MessageControllerTest { ); } - @Test - void testSendFromDisabledAccount() throws Exception { + private static Stream> currentMessageSingleDevicePayloads() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(1); // count + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testSendFromDisabledAccount(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .put(payload); assertThat("Unauthorized response", response.getStatus(), is(equalTo(401))); } - @Test - void testSingleDeviceCurrent() throws Exception { + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testSingleDeviceCurrent(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); assertThat("Good Response", response.getStatus(), is(equalTo(200))); @@ -239,15 +264,15 @@ class MessageControllerTest { ); } - @Test - void testSingleDeviceCurrentByPni() throws Exception { + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testSingleDeviceCurrentByPni(Entity payload) throws Exception { Response response = resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .put(payload); assertThat("Good Response", response.getStatus(), is(equalTo(200))); @@ -271,15 +296,15 @@ class MessageControllerTest { assertThat("Bad request", response.getStatus(), is(equalTo(422))); } - @Test - void testSingleDeviceCurrentUnidentified() throws Exception { + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testSingleDeviceCurrentUnidentified(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) - .request() - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) + .request() + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) + .put(payload); assertThat("Good Response", response.getStatus(), is(equalTo(200))); @@ -290,28 +315,27 @@ class MessageControllerTest { assertFalse(captor.getValue().hasSourceDevice()); } - - @Test - void testSendBadAuth() throws Exception { + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testSendBadAuth(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) - .request() - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) + .request() + .put(payload); assertThat("Good Response", response.getStatus(), is(equalTo(401))); } - @Test - void testMultiDeviceMissing() throws Exception { + @ParameterizedTest + @MethodSource("currentMessageSingleDevicePayloads") + void testMultiDeviceMissing(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); @@ -322,15 +346,15 @@ class MessageControllerTest { verifyNoMoreInteractions(messageSender); } - @Test - void testMultiDeviceExtra() throws Exception { + @ParameterizedTest + @MethodSource + void testMultiDeviceExtra(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_extra_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); @@ -341,29 +365,87 @@ class MessageControllerTest { verifyNoMoreInteractions(messageSender); } - @Test - void testMultiDevice() throws Exception { + private static Stream> testMultiDeviceExtra() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(2); // count + + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + messageStream.write(3); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)111 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_extra_device.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @ParameterizedTest + @MethodSource + void testMultiDevice(Entity payload) throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_multi_device.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); } - @Test - void testRegistrationIdMismatch() throws Exception { + private static Stream> testMultiDevice() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(2); // count + + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + messageStream.write(2); // device ID + messageStream.writeBytes(new byte[] { (byte)1, (byte)77 }); // registration ID (333 = 1 * 256 + 77) + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_multi_device.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @ParameterizedTest + @MethodSource + void testRegistrationIdMismatch(Entity payload) throws Exception { Response response = resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_registration_id.json"), IncomingMessageList.class), - MediaType.APPLICATION_JSON_TYPE)); + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); assertThat("Good Response Code", response.getStatus(), is(equalTo(410))); @@ -372,7 +454,35 @@ class MessageControllerTest { is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json")))); verifyNoMoreInteractions(messageSender); + } + private static Stream> testRegistrationIdMismatch() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(2); // count + + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)222 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + messageStream.write(2); // device ID + messageStream.writeBytes(new byte[] { (byte)0, (byte)77 }); // wrong registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_registration_id.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } } @Test @@ -395,11 +505,10 @@ class MessageControllerTest { OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .accept(MediaType.APPLICATION_JSON_TYPE) - .get(OutgoingMessageEntityList.class); - + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .accept(MediaType.APPLICATION_JSON_TYPE) + .get(OutgoingMessageEntityList.class); assertEquals(response.getMessages().size(), 2); @@ -429,10 +538,10 @@ class MessageControllerTest { Response response = resources.getJerseyTest().target("/v1/messages/") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD)) - .accept(MediaType.APPLICATION_JSON_TYPE) - .get(); + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD)) + .accept(MediaType.APPLICATION_JSON_TYPE) + .get(); assertThat("Unauthorized response", response.getStatus(), is(equalTo(401))); } @@ -453,33 +562,32 @@ class MessageControllerTest { uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, null, System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))); - UUID uuid3 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3)).thenReturn(Optional.empty()); Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/uuid/%s", uuid1)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .delete(); + .target(String.format("/v1/messages/uuid/%s", uuid1)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); verify(receiptSender).sendReceipt(any(AuthenticatedAccount.class), eq(sourceUuid), eq(timestamp)); response = resources.getJerseyTest() - .target(String.format("/v1/messages/uuid/%s", uuid2)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .delete(); + .target(String.format("/v1/messages/uuid/%s", uuid2)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); verifyNoMoreInteractions(receiptSender); response = resources.getJerseyTest() - .target(String.format("/v1/messages/uuid/%s", uuid3)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .delete(); + .target(String.format("/v1/messages/uuid/%s", uuid3)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); verifyNoMoreInteractions(receiptSender); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProviderTest.java new file mode 100644 index 000000000..ea8c3f586 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiDeviceMessageListProviderTest.java @@ -0,0 +1,169 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.providers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.annotation.Annotation; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.MultivaluedHashMap; +import javax.ws.rs.core.NoContentException; + +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.entities.IncomingDeviceMessage; + +public class MultiDeviceMessageListProviderTest { + + static byte[] createByteArray(int... bytes) { + byte[] result = new byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + result[i] = (byte) bytes[i]; + } + return result; + } + + IncomingDeviceMessage[] tryRead(byte[] bytes) throws Exception { + MultiDeviceMessageListProvider provider = new MultiDeviceMessageListProvider(); + return provider.readFrom( + IncomingDeviceMessage[].class, + IncomingDeviceMessage[].class, + new Annotation[] {}, + MediaType.valueOf(MultiDeviceMessageListProvider.MEDIA_TYPE), + new MultivaluedHashMap<>(), + new ByteArrayInputStream(bytes)); + } + + @Test + void testInvalidVersion() { + assertThatThrownBy(() -> tryRead(createByteArray())) + .isInstanceOf(NoContentException.class); + assertThatThrownBy(() -> tryRead(createByteArray(0x00))) + .isInstanceOf(WebApplicationException.class); + assertThatThrownBy(() -> tryRead(createByteArray(0x59))) + .isInstanceOf(WebApplicationException.class); + } + + @Test + void testBadCount() { + assertThatThrownBy(() -> tryRead(createByteArray(MultiDeviceMessageListProvider.VERSION))) + .isInstanceOf(IOException.class); + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, + MultiDeviceMessageListProvider.MAX_MESSAGE_COUNT + 1))) + .isInstanceOf(WebApplicationException.class); + } + + @Test void testBadDeviceId() { + // Missing + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1))) + .isInstanceOf(IOException.class); + // Unfinished varint + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80))) + .isInstanceOf(IOException.class); + } + + @Test void testBadRegistrationId() { + // Missing + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1))) + .isInstanceOf(IOException.class); + // Truncated (fixed u16 value) + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11))) + .isInstanceOf(IOException.class); + } + + @Test void testBadType() { + // Missing + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22))) + .isInstanceOf(IOException.class); + } + + @Test void testBadMessageLength() { + // Missing + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1))) + .isInstanceOf(IOException.class); + // Unfinished varint + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80))) + .isInstanceOf(IOException.class); + // Too big + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x80, 0x80, 0x80, 0x01))) + .isInstanceOf(WebApplicationException.class); + // Missing message + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 0x01))) + .isInstanceOf(IOException.class); + // Truncated message + assertThatThrownBy(() -> tryRead(createByteArray( + MultiDeviceMessageListProvider.VERSION, 1, 0x80, 1, 0x11, 0x22, 1, 5, 0x01, 0x02))) + .isInstanceOf(IOException.class); + } + + @Test + void readThreeMessages() throws Exception { + ByteArrayOutputStream contentStream = new ByteArrayOutputStream(); + contentStream.write(MultiDeviceMessageListProvider.VERSION); + contentStream.write(3); + + contentStream.writeBytes(createByteArray(0x85, 0x02)); // Device ID 0x0105 + contentStream.writeBytes(createByteArray(0x11, 0x05)); // Registration ID 0x1105 + contentStream.write(1); + contentStream.write(5); + contentStream.writeBytes(createByteArray(11, 22, 33, 44, 55)); + + contentStream.write(0x20); // Device ID 0x20 + contentStream.writeBytes(createByteArray(0x30, 0x20)); // Registration ID 0x3020 + contentStream.write(6); + contentStream.writeBytes(createByteArray(0x81, 0x01)); // 129 bytes + contentStream.writeBytes(new byte[129]); + + contentStream.write(1); // Device ID 1 + contentStream.writeBytes(createByteArray(0x00, 0x11)); // Registration ID 0x0011 + contentStream.write(50); + contentStream.write(0); // empty message for some rateLimitReason + + IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray()); + assertThat(messages.length).isEqualTo(3); + + assertThat(messages[0].getDeviceId()).isEqualTo(0x0105); + assertThat(messages[0].getRegistrationId()).isEqualTo(0x1105); + assertThat(messages[0].getType()).isEqualTo(1); + assertThat(messages[0].getContent()).containsExactly(11, 22, 33, 44, 55); + + assertThat(messages[1].getDeviceId()).isEqualTo(0x20); + assertThat(messages[1].getRegistrationId()).isEqualTo(0x3020); + assertThat(messages[1].getType()).isEqualTo(6); + assertThat(messages[1].getContent()).containsExactly(new byte[129]); + + assertThat(messages[2].getDeviceId()).isEqualTo(1); + assertThat(messages[2].getRegistrationId()).isEqualTo(0x0011); + assertThat(messages[2].getType()).isEqualTo(50); + assertThat(messages[2].getContent()).isEmpty(); + } + + @Test + void emptyListIsStillValid() throws Exception { + ByteArrayOutputStream contentStream = new ByteArrayOutputStream(); + contentStream.write(MultiDeviceMessageListProvider.VERSION); + contentStream.write(0); + + IncomingDeviceMessage[] messages = tryRead(contentStream.toByteArray()); + assertThat(messages).isEmpty(); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java index e7ec942fe..3493af234 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java @@ -6,28 +6,38 @@ package org.whispersystems.textsecuregcm.providers; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.params.provider.Arguments.arguments; import java.io.ByteArrayInputStream; +import java.io.IOException; import java.util.stream.Stream; + +import javax.ws.rs.WebApplicationException; + +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; public class MultiRecipientMessageProviderTest { - static byte[] createTwoByteArray(int b1, int b2) { - return new byte[]{(byte) b1, (byte) b2}; + static byte[] createByteArray(int... bytes) { + byte[] result = new byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + result[i] = (byte)bytes[i]; + } + return result; } static Stream readU16TestCases() { return Stream.of( - arguments(0xFFFE, createTwoByteArray(0xFF, 0xFE)), - arguments(0x0001, createTwoByteArray(0x00, 0x01)), - arguments(0xBEEF, createTwoByteArray(0xBE, 0xEF)), - arguments(0xFFFF, createTwoByteArray(0xFF, 0xFF)), - arguments(0x0000, createTwoByteArray(0x00, 0x00)), - arguments(0xF080, createTwoByteArray(0xF0, 0x80)) + arguments(0xFFFE, createByteArray(0xFF, 0xFE)), + arguments(0x0001, createByteArray(0x00, 0x01)), + arguments(0xBEEF, createByteArray(0xBE, 0xEF)), + arguments(0xFFFF, createByteArray(0xFF, 0xFF)), + arguments(0x0000, createByteArray(0x00, 0x00)), + arguments(0xF080, createByteArray(0xF0, 0x80)) ); } @@ -38,7 +48,42 @@ public class MultiRecipientMessageProviderTest { assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue); } } - - + + static Stream readVarintTestCases() { + return Stream.of( + arguments(0x00L, createByteArray(0x00)), + arguments(0x01L, createByteArray(0x01)), + arguments(0x7FL, createByteArray(0x7F)), + arguments(0x80L, createByteArray(0x80, 0x01)), + arguments(0xFFL, createByteArray(0xFF, 0x01)), + arguments(0b1010101_0011001_1100110L, createByteArray(0b1_1100110, 0b1_0011001, 0b0_1010101)), + arguments(0x7FFFFFFF_FFFFFFFFL, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F)), + arguments(-1L, createByteArray(0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01)) + ); + } + + @ParameterizedTest + @MethodSource("readVarintTestCases") + void testReadVarint(long expectedValue, byte[] input) throws Exception { + try (final ByteArrayInputStream stream = new ByteArrayInputStream(input)) { + assertThat(MultiRecipientMessageProvider.readVarint(stream)).isEqualTo(expectedValue); + } + } + + @Test + void testVarintEOF() { + assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0xFF, 0x80)))) + .isInstanceOf(IOException.class); + assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray()))) + .isInstanceOf(IOException.class); + } + + @Test + void testVarintTooBig() { + assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02)))) + .isInstanceOf(WebApplicationException.class); + assertThatThrownBy(() -> MultiRecipientMessageProvider.readVarint(new ByteArrayInputStream(createByteArray(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80)))) + .isInstanceOf(WebApplicationException.class); + } }