From 20392a567bef27f47a593bd9af6e791f07420de8 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer Date: Fri, 1 Dec 2023 14:39:31 -0800 Subject: [PATCH] Revert "multisend cleanup" This reverts commit c03249b411217b890b1684a1f32f49d77b03c1e3. --- .../controllers/MessageController.java | 266 ++++++------ .../controllers/MessageControllerTest.java | 404 +++++++++--------- .../textsecuregcm/util/MockUtils.java | 17 +- 3 files changed, 343 insertions(+), 344 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 3d783670c..adb208c9e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -16,12 +16,6 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; -import io.swagger.v3.oas.annotations.Operation; -import io.swagger.v3.oas.annotations.Parameter; -import io.swagger.v3.oas.annotations.media.Content; -import io.swagger.v3.oas.annotations.media.Schema; -import io.swagger.v3.oas.annotations.responses.ApiResponse; - import java.security.MessageDigest; import java.time.Duration; import java.util.ArrayList; @@ -43,8 +37,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; -import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nonnull; @@ -95,7 +87,6 @@ import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageRespon import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -120,8 +111,6 @@ import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.Stories; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -129,12 +118,6 @@ import reactor.core.scheduler.Scheduler; @io.swagger.v3.oas.annotations.tags.Tag(name = "Messages") public class MessageController { - private record MessageRecipient( - ServiceIdentifier serviceIdentifier, - Account account, - Map perDeviceData) { - } - private static final Logger logger = LoggerFactory.getLogger(MessageController.class); private final RateLimiters rateLimiters; @@ -155,9 +138,9 @@ public class MessageController { private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage"); + private static final String RATE_LIMITED_STORIES_COUNTER_NAME = name(MessageController.class, "rateLimitedStory"); private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType"); - private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage"); private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String SENDER_TYPE_TAG_NAME = "senderType"; @@ -360,25 +343,26 @@ public class MessageController { /** - * Build mapping of service IDs to resolved accounts and device/registration IDs + * Build mapping of accounts to devices/registration IDs. */ - private Map buildRecipientMap( - MultiRecipientMessage multiRecipientMessage, boolean isStory) { - return Flux.fromArray(multiRecipientMessage.recipients()) - .groupBy(Recipient::uuid) - .flatMap( - gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key())) - .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) - .flatMap( - account -> - gf.collectMap(Recipient::deviceId) - .map(perRecipientData -> - new MessageRecipient( - gf.key(), - account, - perRecipientData)))) - .collectMap(MessageRecipient::serviceIdentifier) - .block(); + private Map>> buildDeviceIdAndRegistrationIdMap( + MultiRecipientMessage multiRecipientMessage, + Map accountsByServiceIdentifier) { + + return Arrays.stream(multiRecipientMessage.recipients()) + // for normal messages, all recipients UUIDs are in the map, + // but story messages might specify inactive UUIDs, which we + // have previously filtered + .filter(r -> accountsByServiceIdentifier.containsKey(r.uuid())) + .collect(Collectors.toMap( + recipient -> accountsByServiceIdentifier.get(recipient.uuid()), + recipient -> new HashSet<>( + Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))), + (a, b) -> { + a.addAll(b); + return a; + } + )); } @Timed @@ -387,87 +371,79 @@ public class MessageController { @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @Produces(MediaType.APPLICATION_JSON) @FilterSpam - @Operation( - summary = "Send multi-recipient sealed-sender message", - description = """ - Deliver a common-payload message to multiple recipients. - An unidentifed-access key for all recipients must be provided, unless the message is a story. - """) - @ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true) - @ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times") - @ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect") - @ApiResponse( - responseCode="404", - description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users") - @ApiResponse( - responseCode = "409", description = "Incorrect set of devices supplied for some recipients", - content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class))) - @ApiResponse( - responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", - content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class))) - public Response sendMultiRecipientMessage( - @Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message") @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, - - @Parameter(description="If true, deliver the message only to recipients that are online when it is sent") @QueryParam("online") boolean online, - - @Parameter(description="The sender's timestamp for the envelope") @QueryParam("ts") long timestamp, - - @Parameter(description="If true, this message should cause push notifications to be sent to recipients") @QueryParam("urgent") @DefaultValue("true") final boolean isUrgent, - - @Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted") @QueryParam("story") boolean isStory, - @Parameter(description="The sealed-sender multi-recipient message payload") @NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { - final Map recipients = buildRecipientMap(multiRecipientMessage, isStory); + final Map accountsByServiceIdentifier = new HashMap<>(); + + for (final Recipient recipient : multiRecipientMessage.recipients()) { + if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) { + final Optional maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid()); + + if (maybeAccount.isPresent()) { + accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get()); + } else { + if (!isStory) { + throw new NotFoundException(); + } + } + } + } // Stories will be checked by the client; we bypass access checks here for stories. if (!isStory) { - checkAccessKeys(accessKeys, recipients.values()); + checkAccessKeys(accessKeys, accountsByServiceIdentifier.values()); } - // We might filter out all the recipients of a story (if none exist). + final Map>> accountToDeviceIdAndRegistrationIdMap = + buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier); + + // We might filter out all the recipients of a story (if none have enabled stories). // In this case there is no error so we should just return 200 now. - if (isStory) { - if (recipients.isEmpty()) { - return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build(); - } - for (MessageRecipient recipient : recipients.values()) { - rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid()); - } + if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) { + return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build(); } Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); - recipients.values().forEach(recipient -> { - final Account account = recipient.account(); - try { - DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet()); + for (Map.Entry entry : accountsByServiceIdentifier.entrySet()) { + final ServiceIdentifier serviceIdentifier = entry.getKey(); + final Account account = entry.getValue(); + + if (isStory) { + rateLimiters.getStoriesLimiter().validate(account.getUuid()); + } + + Set deviceIds = accountToDeviceIdAndRegistrationIdMap + .getOrDefault(account, Collections.emptySet()) + .stream() + .map(Pair::first) + .collect(Collectors.toSet()); + + try { + DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); + + // Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number + // identity + DestinationDeviceValidator.validateRegistrationIds( + account, + accountToDeviceIdAndRegistrationIdMap.get(account).stream(), + false); + } catch (MismatchedDevicesException e) { + accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier, + new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); + } catch (StaleDevicesException e) { + accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices()))); + } + } - DestinationDeviceValidator.validateRegistrationIds( - account, - recipient.perDeviceData().values(), - Recipient::deviceId, - Recipient::registrationId, - recipient.serviceIdentifier().identityType() == IdentityType.PNI); - } catch (MismatchedDevicesException e) { - accountMismatchedDevices.add( - new AccountMismatchedDevices( - recipient.serviceIdentifier(), - new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); - } catch (StaleDevicesException e) { - accountStaleDevices.add( - new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); - } - }); if (!accountMismatchedDevices.isEmpty()) { return Response .status(409) @@ -492,28 +468,25 @@ public class MessageController { Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); CompletableFuture.allOf( - recipients.values().stream() - .flatMap(recipientData -> - recipientData.perDeviceData().values().stream().map( - recipient -> CompletableFuture.runAsync( - () -> { - final Account destinationAccount = recipientData.account(); - // we asserted this must exist in validateCompleteDeviceList - final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); - try { - sentMessageCounter.increment(); - sendCommonPayloadMessage( - destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online, - isStory, isUrgent, recipient, multiRecipientMessage.commonPayload()); - } catch (NoSuchUserException e) { - // this should never happen, because we already asserted the device is present and enabled - Metrics.counter( - UNEXPECTED_MISSING_USER_COUNTER_NAME, - Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment(); - uuids404.add(recipientData.serviceIdentifier()); - } - }, - multiRecipientMessageExecutor))) + Arrays.stream(multiRecipientMessage.recipients()) + // If we're sending a story, some recipients might not map to existing accounts + .filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid())) + .map( + recipient -> CompletableFuture.runAsync( + () -> { + Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid()); + + // we asserted this must exist in validateCompleteDeviceList + Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); + sentMessageCounter.increment(); + try { + sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, + recipient, multiRecipientMessage.commonPayload()); + } catch (NoSuchUserException e) { + uuids404.add(recipient.uuid()); + } + }, + multiRecipientMessageExecutor)) .toArray(CompletableFuture[]::new)) .get(); } catch (InterruptedException e) { @@ -529,31 +502,43 @@ public class MessageController { return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } - private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinations) { + private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinationAccounts) { // We should not have null access keys when checking access; bail out early. if (accessKeys == null) { throw new WebApplicationException(Status.UNAUTHORIZED); } - destinations.stream() - .map(MessageRecipient::account) - .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) - .map(account -> account.getUnidentifiedAccessKey().orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) - .reduce( - (bytes, bytes2) -> { - if (bytes.length != bytes2.length) { - throw new WebApplicationException(Status.UNAUTHORIZED); - } - for (int i = 0; i < bytes.length; i++) { - bytes[i] ^= bytes2[i]; - } - return bytes; - }) - .ifPresent( - combinedUnidentifiedAccessKeys -> { - if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) { - throw new WebApplicationException(Status.UNAUTHORIZED); - } - }); + AtomicBoolean throwUnauthorized = new AtomicBoolean(false); + byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]; + final Optional UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + byte[] combinedUnknownAccessKeys = destinationAccounts.stream() + .map(account -> { + if (account.isUnrestrictedUnidentifiedAccess()) { + return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY; + } else { + return account.getUnidentifiedAccessKey(); + } + }) + .map(accessKey -> { + if (accessKey.isEmpty()) { + throwUnauthorized.set(true); + return empty; + } + return accessKey.get(); + }) + .reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> { + if (bytes.length != bytes2.length) { + throwUnauthorized.set(true); + return bytes; + } + for (int i = 0; i < bytes.length; i++) { + bytes[i] ^= bytes2[i]; + } + return bytes; + }); + if (throwUnauthorized.get() + || !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) { + throw new WebApplicationException(Status.UNAUTHORIZED); + } } @Timed @@ -731,7 +716,6 @@ public class MessageController { private void sendCommonPayloadMessage(Account destinationAccount, Device destinationDevice, - ServiceIdentifier serviceIdentifier, long timestamp, boolean online, boolean story, @@ -755,7 +739,7 @@ public class MessageController { .setContent(ByteString.copyFrom(payload)) .setStory(story) .setUrgent(urgent) - .setDestinationUuid(serviceIdentifier.toServiceIdentifierString()); + .setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString()); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); } catch (NotPushRegisteredException e) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 057b40a84..b6c4ae0ee 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.controllers; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.collection.IsEmptyCollection.empty; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -30,7 +29,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; -import static org.whispersystems.textsecuregcm.util.MockUtils.exactly; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableSet; @@ -44,23 +42,21 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedList; import java.util.List; -import java.util.Map; import java.util.Optional; +import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import java.util.function.Function; -import java.util.stream.Collectors; import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.client.Invocation; @@ -77,11 +73,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.ArgumentsSources; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junitpioneer.jupiter.cartesian.ArgumentSets; -import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -99,6 +92,8 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; @@ -144,7 +139,6 @@ class MessageControllerTest { private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111"); private static final byte SINGLE_DEVICE_ID1 = 1; private static final int SINGLE_DEVICE_REG_ID1 = 111; - private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222"); @@ -155,11 +149,6 @@ class MessageControllerTest { private static final int MULTI_DEVICE_REG_ID1 = 222; private static final int MULTI_DEVICE_REG_ID2 = 333; private static final int MULTI_DEVICE_REG_ID3 = 444; - private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222; - private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333; - private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444; - - private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); @@ -203,13 +192,13 @@ class MessageControllerTest { final List singleDeviceList = List.of( - generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) + generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) ); final List multiDeviceList = List.of( - generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) + generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) ); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); @@ -222,8 +211,6 @@ class MessageControllerTest { when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); - when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = mock(DynamicInboundMessageByteLimitConfiguration.class); @@ -935,21 +922,25 @@ class MessageControllerTest { ); } - private record Recipient(ServiceIdentifier uuid, - byte deviceId, - int registrationId, - byte[] perRecipientKeyMaterial) { + private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) { + long x = deviceId; + // write the device-id in the 7-bit varint format we use, least significant bytes first. + do { + long b = x & 0x7f; + x = x >>> 7; + if (x != 0) b |= 0x80; + bb.put((byte)b); + } while (x != 0); } - private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, - final boolean useExplicitIdentifier) { + private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) { if (useExplicitIdentifier) { bb.put(r.uuid().toFixedWidthByteArray()); } else { bb.put(UUIDUtil.toBytes(r.uuid().uuid())); } - bb.put(r.deviceId()); // device id (1 byte) + writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes) bb.putShort((short) r.registrationId()); // registration id (2 bytes) bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) } @@ -962,8 +953,8 @@ class MessageControllerTest { // first write the header bb.put(explicitIdentifiers ? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER - : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte - bb.put((byte)recipients.size()); // count varint + : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte + bb.put((byte)recipients.size()); // count varint Iterator it = recipients.iterator(); while (it.hasNext()) { @@ -977,24 +968,23 @@ class MessageControllerTest { return new ByteArrayInputStream(buffer, 0, bb.position()); } - // see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations - private void testMultiRecipientMessage( - Map> destinations, - boolean authorize, - boolean isStory, - boolean urgent, - boolean explicitIdentifier, - int expectedStatus, - int expectedMessagesSent) throws Exception { - final List recipients = new ArrayList<>(); - destinations.forEach( - (serviceIdentifier, deviceToRegistrationId) -> - deviceToRegistrationId.forEach( - (deviceId, registrationId) -> - recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48])))); + @ParameterizedTest + @MethodSource + void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception { + + final List recipients; + if (recipientUUID == MULTI_DEVICE_UUID) { + recipients = List.of( + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) + ); + } else { + recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + } // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; + //InputStream stream = initializeMultiPayload(recipientUUID, buffer); InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier); // set up the entity to use in our PUT request @@ -1013,160 +1003,124 @@ class MessageControllerTest { // add access header if needed if (authorize) { - final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); - String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]); + String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES); bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); } // make the PUT request Response response = bldr.put(entity); - assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); - verify(messageSender, - exactly(expectedMessagesSent)) - .sendMessage( - any(), - any(), - argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()), - anyBoolean()); - if (expectedStatus == 200) { - SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); - assertThat(smrmr.uuids404(), is(empty())); + if (authorize) { + ArgumentCaptor envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class); + verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean()); + assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent()); + } + + // We have a 2x2x2 grid of possible situations based on: + // - recipient enabled stories? + // - sender is authorized? + // - message is a story? + // + // (urgent is not included in the grid because it has no effect + // on any of the other settings.) + + if (recipientUUID == MULTI_DEVICE_UUID) { + // This is the case where the recipient has enabled stories. + if(isStory) { + // We are sending a story, so we ignore access checks and expect this + // to go out to both the recipient's devices. + checkGoodMultiRecipientResponse(response, 2); + } else { + // We are not sending a story, so we need to do access checks. + if (authorize) { + // When authorized we send a message to the recipient's devices. + checkGoodMultiRecipientResponse(response, 2); + } else { + // When forbidden, we return a 401 error. + checkBadMultiRecipientResponse(response, 401); + } + } + } else { + // This is the case where the recipient has not enabled stories. + if (isStory) { + // We are sending a story, so we ignore access checks. + // this recipient has one device. + checkGoodMultiRecipientResponse(response, 1); + } else { + // We are not sending a story so check access. + if (authorize) { + // If allowed, send a message to the recipient's one device. + checkGoodMultiRecipientResponse(response, 1); + } else { + // If forbidden, return a 401 error. + checkBadMultiRecipientResponse(response, 401); + } + } } } - @SafeVarargs - private static Map submap(Map map, K... keys) { - return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get)); + // Arguments here are: recipient-UUID, is-authorized?, is-story? + private static Stream testMultiRecipientMessage() { + return Stream.of( + Arguments.of(MULTI_DEVICE_UUID, false, true, true, false), + Arguments.of(MULTI_DEVICE_UUID, false, false, true, false), + Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false), + Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false), + Arguments.of(MULTI_DEVICE_UUID, true, true, true, false), + Arguments.of(MULTI_DEVICE_UUID, true, false, true, false), + Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false), + Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false), + Arguments.of(MULTI_DEVICE_UUID, false, true, false, false), + Arguments.of(MULTI_DEVICE_UUID, false, false, false, false), + Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false), + Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false), + Arguments.of(MULTI_DEVICE_UUID, true, true, false, false), + Arguments.of(MULTI_DEVICE_UUID, true, false, false, false), + Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false), + Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false), + Arguments.of(MULTI_DEVICE_UUID, false, true, true, true), + Arguments.of(MULTI_DEVICE_UUID, false, false, true, true), + Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true), + Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true), + Arguments.of(MULTI_DEVICE_UUID, true, true, true, true), + Arguments.of(MULTI_DEVICE_UUID, true, false, true, true), + Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true), + Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true), + Arguments.of(MULTI_DEVICE_UUID, false, true, false, true), + Arguments.of(MULTI_DEVICE_UUID, false, false, false, true), + Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true), + Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true), + Arguments.of(MULTI_DEVICE_UUID, true, true, false, true), + Arguments.of(MULTI_DEVICE_UUID, true, false, false, true), + Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true), + Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true) + ); } - private static Map> multiRecipientTargetMap() { - return - Map.of( - new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), - new AciServiceIdentifier(MULTI_DEVICE_UUID), - Map.of( - MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, - MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2), - new PniServiceIdentifier(MULTI_DEVICE_PNI), - Map.of( - MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1, - MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2), - new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) - ); - } + @Test + void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception { + UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty()); - private record MultiRecipientMessageTestCase( - Map> destinations, - boolean authenticated, - boolean story, - int expectedStatus, - int expectedSentMessages) { - } + final List recipients = List.of( + new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, + new byte[48]), + new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48])); - @CartesianTest - @CartesianTest.MethodFactory("testMultiRecipientMessageNoPni") - void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception { - testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages()); - } + Response response = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("ts", 1700000000000L) + .queryParam("story", true) + .queryParam("urgent", false) + .request() + .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) + .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true), + MultiRecipientMessageProvider.MEDIA_TYPE)); - private static ArgumentSets testMultiRecipientMessageNoPni() { - final Map> targets = multiRecipientTargetMap(); - final Map> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID)); - final Map> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID)); - final Map> bothAccountsAci = - submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID)); - final Map> realAndFakeAci = - submap( - targets, - new AciServiceIdentifier(SINGLE_DEVICE_UUID), - new AciServiceIdentifier(MULTI_DEVICE_UUID), - new AciServiceIdentifier(NONEXISTENT_UUID)); - - final boolean auth = true; - final boolean unauth = false; - final boolean story = true; - final boolean notStory = false; - - return ArgumentSets - .argumentsForFirstParameter( - new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3), - - new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0), - - new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3), - - new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0)) - .argumentsForNextParameter(false, true) // urgent - .argumentsForNextParameter(false, true); // explicitIdentifiers - } - - @CartesianTest - @CartesianTest.MethodFactory("testMultiRecipientMessagePni") - void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception { - testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages()); - } - - private static ArgumentSets testMultiRecipientMessagePni() { - final Map> targets = multiRecipientTargetMap(); - final Map> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI)); - final Map> singleDeviceAciAndPni = submap( - targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI)); - final Map> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI)); - final Map> bothAccountsMixed = - submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI)); - final Map> realAndFakeMixed = - submap( - targets, - new PniServiceIdentifier(SINGLE_DEVICE_PNI), - new AciServiceIdentifier(MULTI_DEVICE_UUID), - new PniServiceIdentifier(NONEXISTENT_UUID)); - - final boolean auth = true; - final boolean unauth = false; - final boolean story = true; - final boolean notStory = false; - - return ArgumentSets - .argumentsForFirstParameter( - new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2), - new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3), - - new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0), - - new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2), - new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3), - - new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2), - new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2), - new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3), - new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0)) - .argumentsForNextParameter(false, true); // urgent + checkGoodMultiRecipientResponse(response, 1); } @ParameterizedTest @@ -1174,7 +1128,7 @@ class MessageControllerTest { void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { final List recipients = List.of( new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); Response response = resources @@ -1372,12 +1326,12 @@ class MessageControllerTest { @ParameterizedTest @MethodSource - void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2) + void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier) throws NotPushRegisteredException, InterruptedException { final List recipients = List.of( - new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48])); + new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; @@ -1410,8 +1364,8 @@ class MessageControllerTest { private static Stream sendMultiRecipientMessage404() { return Stream.of( - Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), - Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); + Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), + Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); } private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { @@ -1419,6 +1373,14 @@ class MessageControllerTest { verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); } + private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception { + assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean()); + SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); + assert (smrmr.uuids404().isEmpty()); + } + private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); @@ -1451,4 +1413,64 @@ class MessageControllerTest { return builder.build(); } + private static Recipient genRecipient(Random rng) { + UUID u1 = UUID.randomUUID(); // non-null + byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127 + int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 + byte[] perKeyBytes = new byte[48]; // size=48, non-null + rng.nextBytes(perKeyBytes); + return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes); + } + + private static void roundTripVarint(byte expected, byte[] bytes) throws Exception { + ByteBuffer bb = ByteBuffer.wrap(bytes); + writePayloadDeviceId(bb, expected); + InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position()); + long got = MultiRecipientMessageProvider.readVarint(stream); + assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes))); + } + + @Test + void testVarintPayload() throws Exception { + Random rng = new Random(); + byte[] bytes = new byte[12]; + + // some static test cases + for (byte i = 1; i <= 10; i++) { + roundTripVarint(i, bytes); + } + roundTripVarint(Byte.MAX_VALUE, bytes); + + for (int i = 0; i < 1000; i++) { + // we need to ensure positive device IDs + byte start = (byte) rng.nextInt(128); + if (start == 0L) { + start = 1; + } + + // run the test for this case + roundTripVarint(start, bytes); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception { + Random rng = new java.util.Random(); + List expected = new LinkedList<>(); + for(int i = 0; i < 100; i++) { + expected.add(genRecipient(rng)); + } + + byte[] buffer = new byte[100 + expected.size() * 100]; + InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers); + MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider(); + // the provider ignores the headers, java reflection, etc. so we don't use those here. + MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream); + List got = Arrays.asList(res.recipients()); + + assertEquals(expected, got); + } + + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index a16a30c93..bca1ad942 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -10,8 +10,6 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted; -import static org.mockito.internal.exceptions.Reporter.tooFewActualInvocations; -import static org.mockito.internal.exceptions.Reporter.tooManyActualInvocations; import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked; import static org.mockito.internal.invocation.InvocationMarker.markVerified; import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified; @@ -28,7 +26,6 @@ import org.mockito.Mockito; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; import org.mockito.verification.VerificationMode; -import org.mockito.internal.verification.Times; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -174,17 +171,10 @@ public final class MockUtils { * this method */ public static VerificationMode exactly() { - return exactly(1); - } - - /** - * a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that - * there are exactly N invocations of this method, and all of them match the given specification - */ - public static VerificationMode exactly(int wantedCount) { return data -> { MatchableInvocation target = data.getTarget(); final List allInvocations = data.getAllInvocations(); + List chunk = findInvocations(allInvocations, target); List otherInvocations = allInvocations.stream() .filter(target::hasSameMethod) .filter(Predicate.not(target::matches)) @@ -194,7 +184,10 @@ public final class MockUtils { Invocation unverified = findFirstUnverified(otherInvocations); throw noMoreInteractionsWanted(unverified, (List) allInvocations); } - Mockito.times(wantedCount).verify(data); + if (chunk.isEmpty()) { + throw wantedButNotInvoked(target); + } + markVerified(chunk.get(0), target); }; }