From 2ab3c97ee8e59591c85f8701cd4c72dedddbe2df Mon Sep 17 00:00:00 2001 From: Jordan Rose Date: Fri, 8 Dec 2023 08:52:47 -0800 Subject: [PATCH] Replace MultiRecipientMessage parsing with libsignal's implementation Co-authored-by: Jonathan Klabunde Tomer --- pom.xml | 2 +- .../controllers/MessageController.java | 72 ++++----- .../entities/MultiRecipientMessage.java | 97 ------------- .../identity/ServiceIdentifier.java | 11 ++ .../MultiRecipientMessageProvider.java | 137 ++---------------- .../textsecuregcm/util/Pair.java | 30 +--- .../controllers/MessageControllerTest.java | 11 +- .../MultiRecipientMessageProviderTest.java | 41 ------ 8 files changed, 64 insertions(+), 337 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java diff --git a/pom.xml b/pom.xml index a84f53826..a3e7cba79 100644 --- a/pom.xml +++ b/pom.xml @@ -284,7 +284,7 @@ org.signal libsignal-server - 0.33.0 + 0.35.0 org.apache.logging.log4j diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index ab18c7673..6302a51eb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -25,13 +25,9 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse; import java.security.MessageDigest; import java.time.Duration; import java.util.ArrayList; -import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -42,9 +38,6 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; -import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -73,6 +66,9 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import org.apache.commons.lang3.StringUtils; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; +import org.signal.libsignal.protocol.util.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.Anonymous; @@ -88,8 +84,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; -import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; -import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMessageResponse; @@ -118,13 +112,13 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; -import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.Stories; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; +import reactor.util.function.Tuples; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/messages") @@ -134,7 +128,8 @@ public class MessageController { private record MultiRecipientDeliveryData( ServiceIdentifier serviceIdentifier, Account account, - Map perDeviceData) { + Recipient recipient, + Map deviceIdToRegistrationId) { } private static final Logger logger = LoggerFactory.getLogger(MessageController.class); @@ -369,20 +364,22 @@ public class MessageController { * Build mapping of service IDs to resolved accounts and device/registration IDs */ private Map buildRecipientMap( - MultiRecipientMessage multiRecipientMessage, boolean isStory) { - return Flux.fromArray(multiRecipientMessage.recipients()) - .groupBy(Recipient::uuid, multiRecipientMessage.recipients().length) + SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) { + return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) + .map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue())) .flatMap( - gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key())) + t -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(t.getT1())) .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) - .flatMap( + .map( account -> - gf.collectMap(Recipient::deviceId) - .map(perRecipientData -> - new MultiRecipientDeliveryData( - gf.key(), - account, - perRecipientData)))) + new MultiRecipientDeliveryData( + t.getT1(), + account, + t.getT2(), + t.getT2().getDevicesAndRegistrationIds().collect( + Collectors.toMap(Pair::first, Pair::second)))) + // IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device + .onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e)) .collectMap(MultiRecipientDeliveryData::serviceIdentifier) .block(); } @@ -429,8 +426,8 @@ public class MessageController { @Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted") @QueryParam("story") boolean isStory, - @Parameter(description="The sealed-sender multi-recipient message payload") - @NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { + @Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal") + @NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { final Map recipients = buildRecipientMap(multiRecipientMessage, isStory); @@ -456,13 +453,13 @@ public class MessageController { final Account account = recipient.account(); try { - DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet()); + DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet()); DestinationDeviceValidator.validateRegistrationIds( account, - recipient.perDeviceData().values(), - Recipient::deviceId, - Recipient::registrationId, + recipient.deviceIdToRegistrationId().entrySet(), + Map.Entry::getKey, + e -> Integer.valueOf(e.getValue()), recipient.serviceIdentifier().identityType() == IdentityType.PNI); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add( @@ -500,17 +497,19 @@ public class MessageController { CompletableFuture.allOf( recipients.values().stream() .flatMap(recipientData -> - recipientData.perDeviceData().values().stream().map( - recipient -> CompletableFuture.runAsync( + recipientData.deviceIdToRegistrationId().keySet().stream().map( + deviceId ->CompletableFuture.runAsync( () -> { final Account destinationAccount = recipientData.account(); + final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient()); + // we asserted this must exist in validateCompleteDeviceList - final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); + final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow(); try { sentMessageCounter.increment(); sendCommonPayloadMessage( destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online, - isStory, isUrgent, recipient, multiRecipientMessage.commonPayload()); + isStory, isUrgent, payload); } catch (NoSuchUserException e) { // this should never happen, because we already asserted the device is present and enabled Metrics.counter( @@ -740,17 +739,10 @@ public class MessageController { boolean online, boolean story, boolean urgent, - Recipient recipient, - byte[] commonPayload) throws NoSuchUserException { + byte[] payload) throws NoSuchUserException { try { Envelope.Builder messageBuilder = Envelope.newBuilder(); long serverTimestamp = System.currentTimeMillis(); - byte[] recipientKeyMaterial = recipient.perRecipientKeyMaterial(); - - byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length]; - payload[0] = MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER; - System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length); - System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length); messageBuilder .setType(Type.UNIDENTIFIED_SENDER) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java deleted file mode 100644 index edec47ab6..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.entities; - -import static com.codahale.metrics.MetricRegistry.name; - -import java.util.Arrays; -import java.util.Objects; -import javax.validation.Valid; -import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.Max; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; -import javax.validation.constraints.Size; - -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import org.whispersystems.textsecuregcm.controllers.MessageController; -import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; -import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; -import org.whispersystems.textsecuregcm.util.Pair; - -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; - -public record MultiRecipientMessage( - @NotNull @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) @Valid Recipient[] recipients, - @NotNull @Size(min = 32) byte[] commonPayload) { - - private static final Counter REJECT_DUPLICATE_RECIPIENT_COUNTER = - Metrics.counter( - name(MessageController.class, "rejectDuplicateRecipients"), - "multiRecipient", "false"); - - public record Recipient(@NotNull - @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) - @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) - ServiceIdentifier uuid, - @Min(1) byte deviceId, - @Min(0) @Max(65535) int registrationId, - @Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) { - - @Override - public boolean equals(final Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - Recipient recipient = (Recipient) o; - return deviceId == recipient.deviceId && registrationId == recipient.registrationId && uuid.equals(recipient.uuid) - && Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial); - } - - @Override - public int hashCode() { - int result = Objects.hash(uuid, deviceId, registrationId); - result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial); - return result; - } - } - - public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) { - this.recipients = recipients; - this.commonPayload = commonPayload; - } - - @AssertTrue - public boolean hasNoDuplicateRecipients() { - boolean valid = - Arrays.stream(recipients).map(r -> new Pair<>(r.uuid(), r.deviceId())).distinct().count() == recipients.length; - if (!valid) { - REJECT_DUPLICATE_RECIPIENT_COUNTER.increment(); - } - return valid; - } - - @Override - public boolean equals(final Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - MultiRecipientMessage that = (MultiRecipientMessage) o; - return Arrays.equals(recipients, that.recipients) && Arrays.equals(commonPayload, that.commonPayload); - } - - @Override - public int hashCode() { - int result = Arrays.hashCode(recipients); - result = 31 * result + Arrays.hashCode(commonPayload); - return result; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java index c012f924c..ab52312a8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.identity; import io.swagger.v3.oas.annotations.media.Schema; import java.util.UUID; +import org.signal.libsignal.protocol.ServiceId; /** * A "service identifier" is a tuple of a UUID and identity type that identifies an account and identity within the @@ -70,4 +71,14 @@ public interface ServiceIdentifier { return PniServiceIdentifier.fromBytes(bytes); } } + + static ServiceIdentifier fromLibsignal(final ServiceId libsignalServiceId) { + if (libsignalServiceId instanceof ServiceId.Aci) { + return new AciServiceIdentifier(libsignalServiceId.getRawUUID()); + } + if (libsignalServiceId instanceof ServiceId.Pni) { + return new PniServiceIdentifier(libsignalServiceId.getRawUUID()); + } + throw new IllegalArgumentException("unknown libsignal ServiceId type"); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index bdcf02ef9..236789b84 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.providers; -import com.google.common.annotations.VisibleForTesting; import io.dropwizard.util.DataSizeUnit; import java.io.IOException; import java.io.InputStream; @@ -19,150 +18,36 @@ import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.NoContentException; import javax.ws.rs.ext.MessageBodyReader; import javax.ws.rs.ext.Provider; -import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; -import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.signal.libsignal.protocol.InvalidMessageException; +import org.signal.libsignal.protocol.InvalidVersionException; @Provider @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) -public class MultiRecipientMessageProvider implements MessageBodyReader { +public class MultiRecipientMessageProvider implements MessageBodyReader { public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm"; public static final int MAX_RECIPIENT_COUNT = 5000; public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256)); - public static final byte AMBIGUOUS_ID_VERSION_IDENTIFIER = 0x22; - public static final byte EXPLICIT_ID_VERSION_IDENTIFIER = 0x23; - - private enum Version { - AMBIGUOUS_ID(AMBIGUOUS_ID_VERSION_IDENTIFIER), - EXPLICIT_ID(EXPLICIT_ID_VERSION_IDENTIFIER); - - private final byte identifier; - - Version(final byte identifier) { - this.identifier = identifier; - } - - static Version forVersionByte(final byte versionByte) { - for (final Version version : values()) { - if (version.identifier == versionByte) { - return version; - } - } - - throw new IllegalArgumentException("Unrecognized version byte: " + versionByte); - } - } - @Override public boolean isReadable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { - return MEDIA_TYPE.equals(mediaType.toString()) && MultiRecipientMessage.class.isAssignableFrom(type); + return MEDIA_TYPE.equals(mediaType.toString()) && SealedSenderMultiRecipientMessage.class.isAssignableFrom(type); } @Override - public MultiRecipientMessage readFrom(Class type, Type genericType, Annotation[] annotations, + public SealedSenderMultiRecipientMessage readFrom(Class type, Type genericType, Annotation[] annotations, MediaType mediaType, MultivaluedMap httpHeaders, InputStream entityStream) throws IOException, WebApplicationException { - int versionByte = entityStream.read(); - if (versionByte == -1) { + byte[] fullMessage = entityStream.readNBytes(MAX_MESSAGE_SIZE + MAX_RECIPIENT_COUNT * 100); + if (fullMessage.length == 0) { throw new NoContentException("Empty body not allowed"); } - final Version version; - try { - version = Version.forVersionByte((byte) versionByte); - } catch (final IllegalArgumentException e) { - throw new BadRequestException("Unsupported version"); + return SealedSenderMultiRecipientMessage.parse(fullMessage); + } catch (InvalidMessageException | InvalidVersionException e) { + throw new BadRequestException(e); } - - long count = readVarint(entityStream); - if (count > MAX_RECIPIENT_COUNT) { - throw new BadRequestException("Maximum recipient count exceeded"); - } - MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)]; - for (int i = 0; i < Math.toIntExact(count); i++) { - ServiceIdentifier identifier = readIdentifier(entityStream, version); - final byte deviceId; - { - long deviceIdLong = readVarint(entityStream); - if (deviceIdLong > Byte.MAX_VALUE) { - throw new BadRequestException("Invalid device ID"); - } - deviceId = (byte) deviceIdLong; - } - int registrationId = readU16(entityStream); - byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); - if (perRecipientKeyMaterial.length != 48) { - throw new IOException("Failed to read expected number of key material bytes for a recipient"); - } - recipients[i] = new MultiRecipientMessage.Recipient(identifier, deviceId, registrationId, perRecipientKeyMaterial); - } - - // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than - // this it'll return an error back. We just need to limit how many we'll accept here. - byte[] commonPayload = entityStream.readNBytes(MAX_MESSAGE_SIZE); - if (commonPayload.length < 32) { - throw new IOException("Failed to read expected number of common key material bytes"); - } - return new MultiRecipientMessage(recipients, commonPayload); - } - - /** - * Reads a service identifier from the given stream. - */ - private ServiceIdentifier readIdentifier(final InputStream stream, final Version version) throws IOException { - final byte[] uuidBytes = switch (version) { - case AMBIGUOUS_ID -> stream.readNBytes(16); - case EXPLICIT_ID -> stream.readNBytes(17); - }; - - return ServiceIdentifier.fromBytes(uuidBytes); - } - - /** - * Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An - * {@code IOException} is thrown if the stream ends before we finish reading the varint. - * - * @return the varint value - */ - @VisibleForTesting - public static long readVarint(InputStream stream) throws IOException, WebApplicationException { - boolean hasMore = true; - int currentOffset = 0; - long result = 0; - while (hasMore) { - if (currentOffset >= 64) { - throw new BadRequestException("varint is too large"); - } - int b = stream.read(); - if (b == -1) { - throw new IOException("Missing byte " + (currentOffset / 7) + " of varint"); - } - if (currentOffset == 63 && (b & 0xFE) != 0) { - throw new BadRequestException("varint is too large"); - } - hasMore = (b & 0x80) != 0; - result |= (b & 0x7FL) << currentOffset; - currentOffset += 7; - } - return result; - } - - /** - * Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is - * {@code [0, 65535]}. - */ - @VisibleForTesting - static int readU16(InputStream stream) throws IOException { - int b1 = stream.read(); - if (b1 == -1) { - throw new IOException("Missing byte 1 of U16"); - } - int b2 = stream.read(); - if (b2 == -1) { - throw new IOException("Missing byte 2 of U16"); - } - return (b1 << 8) | b2; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java index 014a5e906..70fab8a98 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java @@ -4,32 +4,14 @@ */ package org.whispersystems.textsecuregcm.util; -import static com.google.common.base.Objects.equal; +import java.util.Map; -public class Pair { - private final T1 v1; - private final T2 v2; - - public Pair(T1 v1, T2 v2) { - this.v1 = v1; - this.v2 = v2; +public record Pair(T1 first, T2 second) { + public Pair(org.signal.libsignal.protocol.util.Pair p) { + this(p.first(), p.second()); } - public T1 first() { - return v1; - } - - public T2 second() { - return v2; - } - - public boolean equals(Object o) { - return o instanceof Pair && - equal(((Pair) o).first(), first()) && - equal(((Pair) o).second(), second()); - } - - public int hashCode() { - return first().hashCode() ^ second().hashCode(); + public Pair(Map.Entry e) { + this(e.getKey(), e.getValue()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 9826f93aa..06fca3edb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -8,8 +8,8 @@ package org.whispersystems.textsecuregcm.controllers; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.collection.IsEmptyCollection.empty; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsEmptyCollection.empty; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; @@ -19,7 +19,6 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -55,7 +54,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -78,7 +76,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.ArgumentsSources; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.ArgumentSets; @@ -981,9 +978,7 @@ class MessageControllerTest { bb.order(ByteOrder.BIG_ENDIAN); // first write the header - bb.put(explicitIdentifiers - ? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER - : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte + bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte // count varint int nRecip = recipients.size(); @@ -1258,7 +1253,7 @@ class MessageControllerTest { .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE)); - checkBadMultiRecipientResponse(response, 422); + checkBadMultiRecipientResponse(response, 400); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java deleted file mode 100644 index ff3240759..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProviderTest.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.providers; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.params.provider.Arguments.arguments; - -import java.io.ByteArrayInputStream; -import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -public class MultiRecipientMessageProviderTest { - - static byte[] createTwoByteArray(int b1, int b2) { - return new byte[]{(byte) b1, (byte) b2}; - } - - static Stream readU16TestCases() { - return Stream.of( - arguments(0xFFFE, createTwoByteArray(0xFF, 0xFE)), - arguments(0x0001, createTwoByteArray(0x00, 0x01)), - arguments(0xBEEF, createTwoByteArray(0xBE, 0xEF)), - arguments(0xFFFF, createTwoByteArray(0xFF, 0xFF)), - arguments(0x0000, createTwoByteArray(0x00, 0x00)), - arguments(0xF080, createTwoByteArray(0xF0, 0x80)) - ); - } - - @ParameterizedTest - @MethodSource("readU16TestCases") - void testReadU16(int expectedValue, byte[] input) throws Exception { - try (final ByteArrayInputStream stream = new ByteArrayInputStream(input)) { - assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue); - } - } -}