diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 906c015fe..ab18c7673 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -16,6 +16,12 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; + import java.security.MessageDigest; import java.time.Duration; import java.util.ArrayList; @@ -37,7 +43,11 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -112,6 +122,8 @@ import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.Stories; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -119,6 +131,12 @@ import reactor.core.scheduler.Scheduler; @io.swagger.v3.oas.annotations.tags.Tag(name = "Messages") public class MessageController { + private record MultiRecipientDeliveryData( + ServiceIdentifier serviceIdentifier, + Account account, + Map perDeviceData) { + } + private static final Logger logger = LoggerFactory.getLogger(MessageController.class); private final RateLimiters rateLimiters; @@ -141,6 +159,7 @@ public class MessageController { private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage"); private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType"); + private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage"); private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String SENDER_TYPE_TAG_NAME = "senderType"; @@ -347,26 +366,25 @@ public class MessageController { /** - * Build mapping of accounts to devices/registration IDs. + * Build mapping of service IDs to resolved accounts and device/registration IDs */ - private Map>> buildDeviceIdAndRegistrationIdMap( - MultiRecipientMessage multiRecipientMessage, - Map accountsByServiceIdentifier) { - - return Arrays.stream(multiRecipientMessage.recipients()) - // for normal messages, all recipients UUIDs are in the map, - // but story messages might specify inactive UUIDs, which we - // have previously filtered - .filter(r -> accountsByServiceIdentifier.containsKey(r.uuid())) - .collect(Collectors.toMap( - recipient -> accountsByServiceIdentifier.get(recipient.uuid()), - recipient -> new HashSet<>( - Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))), - (a, b) -> { - a.addAll(b); - return a; - } - )); + private Map buildRecipientMap( + MultiRecipientMessage multiRecipientMessage, boolean isStory) { + return Flux.fromArray(multiRecipientMessage.recipients()) + .groupBy(Recipient::uuid, multiRecipientMessage.recipients().length) + .flatMap( + gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key())) + .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) + .flatMap( + account -> + gf.collectMap(Recipient::deviceId) + .map(perRecipientData -> + new MultiRecipientDeliveryData( + gf.key(), + account, + perRecipientData)))) + .collectMap(MultiRecipientDeliveryData::serviceIdentifier) + .block(); } @Timed @@ -375,79 +393,87 @@ public class MessageController { @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @Produces(MediaType.APPLICATION_JSON) @FilterSpam + @Operation( + summary = "Send multi-recipient sealed-sender message", + description = """ + Deliver a common-payload message to multiple recipients. + An unidentifed-access key for all recipients must be provided, unless the message is a story. + """) + @ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true) + @ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times") + @ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect") + @ApiResponse( + responseCode="404", + description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users") + @ApiResponse( + responseCode = "409", description = "Incorrect set of devices supplied for some recipients", + content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class))) + @ApiResponse( + responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", + content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class))) + public Response sendMultiRecipientMessage( + @Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message") @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, + + @Parameter(description="If true, deliver the message only to recipients that are online when it is sent") @QueryParam("online") boolean online, + + @Parameter(description="The sender's timestamp for the envelope") @QueryParam("ts") long timestamp, + + @Parameter(description="If true, this message should cause push notifications to be sent to recipients") @QueryParam("urgent") @DefaultValue("true") final boolean isUrgent, + + @Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted") @QueryParam("story") boolean isStory, + @Parameter(description="The sealed-sender multi-recipient message payload") @NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { - final Map accountsByServiceIdentifier = new HashMap<>(); - - for (final Recipient recipient : multiRecipientMessage.recipients()) { - if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) { - final Optional maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid()); - - if (maybeAccount.isPresent()) { - accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get()); - } else { - if (!isStory) { - throw new NotFoundException(); - } - } - } - } + final Map recipients = buildRecipientMap(multiRecipientMessage, isStory); // Stories will be checked by the client; we bypass access checks here for stories. if (!isStory) { - checkAccessKeys(accessKeys, accountsByServiceIdentifier.values()); + checkAccessKeys(accessKeys, recipients.values()); } - final Map>> accountToDeviceIdAndRegistrationIdMap = - buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier); - - // We might filter out all the recipients of a story (if none have enabled stories). + // We might filter out all the recipients of a story (if none exist). // In this case there is no error so we should just return 200 now. - if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) { - return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build(); + if (isStory) { + if (recipients.isEmpty()) { + return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build(); + } + for (MultiRecipientDeliveryData recipient : recipients.values()) { + rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid()); + } } Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); + recipients.values().forEach(recipient -> { + final Account account = recipient.account(); - for (Map.Entry entry : accountsByServiceIdentifier.entrySet()) { - final ServiceIdentifier serviceIdentifier = entry.getKey(); - final Account account = entry.getValue(); - - if (isStory) { - rateLimiters.getStoriesLimiter().validate(account.getUuid()); - } - - Set deviceIds = accountToDeviceIdAndRegistrationIdMap - .getOrDefault(account, Collections.emptySet()) - .stream() - .map(Pair::first) - .collect(Collectors.toSet()); - - try { - DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); - - // Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number - // identity - DestinationDeviceValidator.validateRegistrationIds( - account, - accountToDeviceIdAndRegistrationIdMap.get(account).stream(), - false); - } catch (MismatchedDevicesException e) { - accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier, - new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); - } catch (StaleDevicesException e) { - accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices()))); - } - } + try { + DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet()); + DestinationDeviceValidator.validateRegistrationIds( + account, + recipient.perDeviceData().values(), + Recipient::deviceId, + Recipient::registrationId, + recipient.serviceIdentifier().identityType() == IdentityType.PNI); + } catch (MismatchedDevicesException e) { + accountMismatchedDevices.add( + new AccountMismatchedDevices( + recipient.serviceIdentifier(), + new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); + } catch (StaleDevicesException e) { + accountStaleDevices.add( + new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); + } + }); if (!accountMismatchedDevices.isEmpty()) { return Response .status(409) @@ -472,25 +498,28 @@ public class MessageController { Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); CompletableFuture.allOf( - Arrays.stream(multiRecipientMessage.recipients()) - // If we're sending a story, some recipients might not map to existing accounts - .filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid())) - .map( - recipient -> CompletableFuture.runAsync( - () -> { - Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid()); - - // we asserted this must exist in validateCompleteDeviceList - Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); - sentMessageCounter.increment(); - try { - sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, - recipient, multiRecipientMessage.commonPayload()); - } catch (NoSuchUserException e) { - uuids404.add(recipient.uuid()); - } - }, - multiRecipientMessageExecutor)) + recipients.values().stream() + .flatMap(recipientData -> + recipientData.perDeviceData().values().stream().map( + recipient -> CompletableFuture.runAsync( + () -> { + final Account destinationAccount = recipientData.account(); + // we asserted this must exist in validateCompleteDeviceList + final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); + try { + sentMessageCounter.increment(); + sendCommonPayloadMessage( + destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online, + isStory, isUrgent, recipient, multiRecipientMessage.commonPayload()); + } catch (NoSuchUserException e) { + // this should never happen, because we already asserted the device is present and enabled + Metrics.counter( + UNEXPECTED_MISSING_USER_COUNTER_NAME, + Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment(); + uuids404.add(recipientData.serviceIdentifier()); + } + }, + multiRecipientMessageExecutor))) .toArray(CompletableFuture[]::new)) .get(); } catch (InterruptedException e) { @@ -506,41 +535,27 @@ public class MessageController { return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } - private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinationAccounts) { + private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinations) { // We should not have null access keys when checking access; bail out early. if (accessKeys == null) { throw new WebApplicationException(Status.UNAUTHORIZED); } - AtomicBoolean throwUnauthorized = new AtomicBoolean(false); - byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]; - final Optional UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - byte[] combinedUnknownAccessKeys = destinationAccounts.stream() - .map(account -> { - if (account.isUnrestrictedUnidentifiedAccess()) { - return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY; - } else { - return account.getUnidentifiedAccessKey(); - } - }) - .map(accessKey -> { - if (accessKey.isEmpty()) { - throwUnauthorized.set(true); - return empty; - } - return accessKey.get(); - }) - .reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> { - if (bytes.length != bytes2.length) { - throwUnauthorized.set(true); - return bytes; - } - for (int i = 0; i < bytes.length; i++) { - bytes[i] ^= bytes2[i]; - } - return bytes; - }); - if (throwUnauthorized.get() - || !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) { + + final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH; + final byte[] combinedUnidentifiedAccessKeys = destinations.stream() + .map(MultiRecipientDeliveryData::account) + .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) + .map(account -> + account.getUnidentifiedAccessKey() + .filter(b -> b.length == keyLength) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) + .reduce(new byte[keyLength], + (a, b) -> { + final byte[] xor = new byte[keyLength]; + IntStream.range(0, keyLength).forEach(i -> xor[i] = (byte) (a[i] ^ b[i])); + return xor; + }); + if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) { throw new WebApplicationException(Status.UNAUTHORIZED); } } @@ -720,6 +735,7 @@ public class MessageController { private void sendCommonPayloadMessage(Account destinationAccount, Device destinationDevice, + ServiceIdentifier serviceIdentifier, long timestamp, boolean online, boolean story, @@ -743,7 +759,7 @@ public class MessageController { .setContent(ByteString.copyFrom(payload)) .setStory(story) .setUrgent(urgent) - .setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString()); + .setDestinationUuid(serviceIdentifier.toServiceIdentifierString()); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); } catch (NotPushRegisteredException e) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 3a3fd963c..9826f93aa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.collection.IsEmptyCollection.empty; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -29,6 +30,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; +import static org.whispersystems.textsecuregcm.util.MockUtils.exactly; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableSet; @@ -42,21 +44,24 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.client.Invocation; @@ -73,8 +78,11 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsSources; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.ArgumentSets; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -92,8 +100,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; -import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; -import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; @@ -139,6 +145,7 @@ class MessageControllerTest { private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111"); private static final byte SINGLE_DEVICE_ID1 = 1; private static final int SINGLE_DEVICE_REG_ID1 = 111; + private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222"); @@ -149,6 +156,11 @@ class MessageControllerTest { private static final int MULTI_DEVICE_REG_ID1 = 222; private static final int MULTI_DEVICE_REG_ID2 = 333; private static final int MULTI_DEVICE_REG_ID3 = 444; + private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222; + private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333; + private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444; + + private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); @@ -192,13 +204,13 @@ class MessageControllerTest { final List singleDeviceList = List.of( - generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) + generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) ); final List multiDeviceList = List.of( - generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) + generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) ); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); @@ -211,6 +223,8 @@ class MessageControllerTest { when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = mock(DynamicInboundMessageByteLimitConfiguration.class); @@ -942,25 +956,21 @@ class MessageControllerTest { ); } - private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) { - long x = deviceId; - // write the device-id in the 7-bit varint format we use, least significant bytes first. - do { - long b = x & 0x7f; - x = x >>> 7; - if (x != 0) b |= 0x80; - bb.put((byte)b); - } while (x != 0); + private record Recipient(ServiceIdentifier uuid, + byte deviceId, + int registrationId, + byte[] perRecipientKeyMaterial) { } - private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) { + private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, + final boolean useExplicitIdentifier) { if (useExplicitIdentifier) { bb.put(r.uuid().toFixedWidthByteArray()); } else { bb.put(UUIDUtil.toBytes(r.uuid().uuid())); } - writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes) + bb.put(r.deviceId()); // device id (1 byte) bb.putShort((short) r.registrationId()); // registration id (2 bytes) bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) } @@ -973,8 +983,15 @@ class MessageControllerTest { // first write the header bb.put(explicitIdentifiers ? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER - : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte - bb.put((byte)recipients.size()); // count varint + : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte + + // count varint + int nRecip = recipients.size(); + while (nRecip > 127) { + bb.put((byte) (nRecip & 0x7F | 0x80)); + nRecip = nRecip >> 7; + } + bb.put((byte)(nRecip & 0x7F)); Iterator it = recipients.iterator(); while (it.hasNext()) { @@ -988,23 +1005,65 @@ class MessageControllerTest { return new ByteArrayInputStream(buffer, 0, bb.position()); } - @ParameterizedTest - @MethodSource - void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception { + @Test + void testManyRecipientMessage() throws Exception { + final int nRecipients = 999; + final int devicesPerRecipient = 5; + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final List recipients = new ArrayList<>(); - final List recipients; - if (recipientUUID == MULTI_DEVICE_UUID) { - recipients = List.of( - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) - ); - } else { - recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + for (int i = 0; i < nRecipients; i++) { + final List devices = + IntStream.range(1, devicesPerRecipient + 1) + .mapToObj( + d -> generateTestDevice( + (byte) d, 100 + d, 10 * d, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), + System.currentTimeMillis())) + .collect(Collectors.toList()); + final UUID aci = new UUID(0L, (long) i); + final UUID pni = new UUID(1L, (long) i); + final String e164 = String.format("+1408555%04d", i); + final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci))).thenReturn(Optional.of(account)); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni))).thenReturn(Optional.of(account)); + devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48]))); } + byte[] buffer = new byte[1048576]; + InputStream stream = initializeMultiPayload(recipients, buffer, true); + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + final Response response = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("story", true) + .queryParam("urgent", false) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .put(entity); + + assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200))); + verify(messageSender, times(nRecipients * devicesPerRecipient)).sendMessage(any(), any(), any(), eq(true)); + } + + // see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations + private void testMultiRecipientMessage( + Map> destinations, + boolean authorize, + boolean isStory, + boolean urgent, + boolean explicitIdentifier, + int expectedStatus, + int expectedMessagesSent) throws Exception { + final List recipients = new ArrayList<>(); + destinations.forEach( + (serviceIdentifier, deviceToRegistrationId) -> + deviceToRegistrationId.forEach( + (deviceId, registrationId) -> + recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48])))); + // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; - //InputStream stream = initializeMultiPayload(recipientUUID, buffer); InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier); // set up the entity to use in our PUT request @@ -1023,124 +1082,160 @@ class MessageControllerTest { // add access header if needed if (authorize) { - String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES); + final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); + String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]); bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); } // make the PUT request Response response = bldr.put(entity); - if (authorize) { - ArgumentCaptor envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean()); - assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent()); - } - - // We have a 2x2x2 grid of possible situations based on: - // - recipient enabled stories? - // - sender is authorized? - // - message is a story? - // - // (urgent is not included in the grid because it has no effect - // on any of the other settings.) - - if (recipientUUID == MULTI_DEVICE_UUID) { - // This is the case where the recipient has enabled stories. - if(isStory) { - // We are sending a story, so we ignore access checks and expect this - // to go out to both the recipient's devices. - checkGoodMultiRecipientResponse(response, 2); - } else { - // We are not sending a story, so we need to do access checks. - if (authorize) { - // When authorized we send a message to the recipient's devices. - checkGoodMultiRecipientResponse(response, 2); - } else { - // When forbidden, we return a 401 error. - checkBadMultiRecipientResponse(response, 401); - } - } - } else { - // This is the case where the recipient has not enabled stories. - if (isStory) { - // We are sending a story, so we ignore access checks. - // this recipient has one device. - checkGoodMultiRecipientResponse(response, 1); - } else { - // We are not sending a story so check access. - if (authorize) { - // If allowed, send a message to the recipient's one device. - checkGoodMultiRecipientResponse(response, 1); - } else { - // If forbidden, return a 401 error. - checkBadMultiRecipientResponse(response, 401); - } - } + assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); + verify(messageSender, + exactly(expectedMessagesSent)) + .sendMessage( + any(), + any(), + argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()), + eq(true)); + if (expectedStatus == 200) { + SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); + assertThat(smrmr.uuids404(), is(empty())); } } - // Arguments here are: recipient-UUID, is-authorized?, is-story? - private static Stream testMultiRecipientMessage() { - return Stream.of( - Arguments.of(MULTI_DEVICE_UUID, false, true, true, false), - Arguments.of(MULTI_DEVICE_UUID, false, false, true, false), - Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false), - Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false), - Arguments.of(MULTI_DEVICE_UUID, true, true, true, false), - Arguments.of(MULTI_DEVICE_UUID, true, false, true, false), - Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false), - Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false), - Arguments.of(MULTI_DEVICE_UUID, false, true, false, false), - Arguments.of(MULTI_DEVICE_UUID, false, false, false, false), - Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false), - Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false), - Arguments.of(MULTI_DEVICE_UUID, true, true, false, false), - Arguments.of(MULTI_DEVICE_UUID, true, false, false, false), - Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false), - Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false), - Arguments.of(MULTI_DEVICE_UUID, false, true, true, true), - Arguments.of(MULTI_DEVICE_UUID, false, false, true, true), - Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true), - Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true), - Arguments.of(MULTI_DEVICE_UUID, true, true, true, true), - Arguments.of(MULTI_DEVICE_UUID, true, false, true, true), - Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true), - Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true), - Arguments.of(MULTI_DEVICE_UUID, false, true, false, true), - Arguments.of(MULTI_DEVICE_UUID, false, false, false, true), - Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true), - Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true), - Arguments.of(MULTI_DEVICE_UUID, true, true, false, true), - Arguments.of(MULTI_DEVICE_UUID, true, false, false, true), - Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true), - Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true) - ); + @SafeVarargs + private static Map submap(Map map, K... keys) { + return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get)); } - @Test - void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception { - UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty()); + private static Map> multiRecipientTargetMap() { + return + Map.of( + new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), + new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), + new AciServiceIdentifier(MULTI_DEVICE_UUID), + Map.of( + MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, + MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2), + new PniServiceIdentifier(MULTI_DEVICE_PNI), + Map.of( + MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1, + MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2), + new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), + new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) + ); + } - final List recipients = List.of( - new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, - new byte[48]), - new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48])); + private record MultiRecipientMessageTestCase( + Map> destinations, + boolean authenticated, + boolean story, + int expectedStatus, + int expectedSentMessages) { + } - Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1700000000000L) - .queryParam("story", true) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) - .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true), - MultiRecipientMessageProvider.MEDIA_TYPE)); + @CartesianTest + @CartesianTest.MethodFactory("testMultiRecipientMessageNoPni") + void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception { + testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages()); + } - checkGoodMultiRecipientResponse(response, 1); + private static ArgumentSets testMultiRecipientMessageNoPni() { + final Map> targets = multiRecipientTargetMap(); + final Map> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID)); + final Map> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID)); + final Map> bothAccountsAci = + submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID)); + final Map> realAndFakeAci = + submap( + targets, + new AciServiceIdentifier(SINGLE_DEVICE_UUID), + new AciServiceIdentifier(MULTI_DEVICE_UUID), + new AciServiceIdentifier(NONEXISTENT_UUID)); + + final boolean auth = true; + final boolean unauth = false; + final boolean story = true; + final boolean notStory = false; + + return ArgumentSets + .argumentsForFirstParameter( + new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1), + new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3), + + new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0), + + new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1), + new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3), + + new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1), + new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0)) + .argumentsForNextParameter(false, true) // urgent + .argumentsForNextParameter(false, true); // explicitIdentifiers + } + + @CartesianTest + @CartesianTest.MethodFactory("testMultiRecipientMessagePni") + void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception { + testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages()); + } + + private static ArgumentSets testMultiRecipientMessagePni() { + final Map> targets = multiRecipientTargetMap(); + final Map> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI)); + final Map> singleDeviceAciAndPni = submap( + targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI)); + final Map> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI)); + final Map> bothAccountsMixed = + submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI)); + final Map> realAndFakeMixed = + submap( + targets, + new PniServiceIdentifier(SINGLE_DEVICE_PNI), + new AciServiceIdentifier(MULTI_DEVICE_UUID), + new PniServiceIdentifier(NONEXISTENT_UUID)); + + final boolean auth = true; + final boolean unauth = false; + final boolean story = true; + final boolean notStory = false; + + return ArgumentSets + .argumentsForFirstParameter( + new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1), + new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2), + new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3), + + new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0), + new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0), + + new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1), + new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2), + new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3), + + new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1), + new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2), + new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2), + new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3), + new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0)) + .argumentsForNextParameter(false, true); // urgent } @ParameterizedTest @@ -1148,7 +1243,7 @@ class MessageControllerTest { void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { final List recipients = List.of( new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); Response response = resources @@ -1346,12 +1441,12 @@ class MessageControllerTest { @ParameterizedTest @MethodSource - void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier) + void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2) throws NotPushRegisteredException, InterruptedException { final List recipients = List.of( - new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); + new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48])); // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; @@ -1384,8 +1479,8 @@ class MessageControllerTest { private static Stream sendMultiRecipientMessage404() { return Stream.of( - Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), - Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); + Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), + Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); } private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { @@ -1393,14 +1488,6 @@ class MessageControllerTest { verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); } - private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception { - assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); - ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); - verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean()); - SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); - assert (smrmr.uuids404().isEmpty()); - } - private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); @@ -1433,64 +1520,4 @@ class MessageControllerTest { return builder.build(); } - private static Recipient genRecipient(Random rng) { - UUID u1 = UUID.randomUUID(); // non-null - byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127 - int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 - byte[] perKeyBytes = new byte[48]; // size=48, non-null - rng.nextBytes(perKeyBytes); - return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes); - } - - private static void roundTripVarint(byte expected, byte[] bytes) throws Exception { - ByteBuffer bb = ByteBuffer.wrap(bytes); - writePayloadDeviceId(bb, expected); - InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position()); - long got = MultiRecipientMessageProvider.readVarint(stream); - assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes))); - } - - @Test - void testVarintPayload() throws Exception { - Random rng = new Random(); - byte[] bytes = new byte[12]; - - // some static test cases - for (byte i = 1; i <= 10; i++) { - roundTripVarint(i, bytes); - } - roundTripVarint(Byte.MAX_VALUE, bytes); - - for (int i = 0; i < 1000; i++) { - // we need to ensure positive device IDs - byte start = (byte) rng.nextInt(128); - if (start == 0L) { - start = 1; - } - - // run the test for this case - roundTripVarint(start, bytes); - } - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception { - Random rng = new java.util.Random(); - List expected = new LinkedList<>(); - for(int i = 0; i < 100; i++) { - expected.add(genRecipient(rng)); - } - - byte[] buffer = new byte[100 + expected.size() * 100]; - InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers); - MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider(); - // the provider ignores the headers, java reflection, etc. so we don't use those here. - MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream); - List got = Arrays.asList(res.recipients()); - - assertEquals(expected, got); - } - - } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index 0750de537..dbe61f577 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -10,10 +10,7 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted; -import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked; -import static org.mockito.internal.invocation.InvocationMarker.markVerified; import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified; -import static org.mockito.internal.invocation.InvocationsFinder.findInvocations; import java.time.Duration; import java.util.List; @@ -169,10 +166,17 @@ public final class MockUtils { * this method */ public static VerificationMode exactly() { + return exactly(1); + } + + /** + * a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that + * there are exactly N invocations of this method, and all of them match the given specification + */ + public static VerificationMode exactly(int wantedCount) { return data -> { MatchableInvocation target = data.getTarget(); final List allInvocations = data.getAllInvocations(); - List chunk = findInvocations(allInvocations, target); List otherInvocations = allInvocations.stream() .filter(target::hasSameMethod) .filter(Predicate.not(target::matches)) @@ -182,10 +186,7 @@ public final class MockUtils { Invocation unverified = findFirstUnverified(otherInvocations); throw noMoreInteractionsWanted(unverified, (List) allInvocations); } - if (chunk.isEmpty()) { - throw wantedButNotInvoked(target); - } - markVerified(chunk.get(0), target); + Mockito.times(wantedCount).verify(data); }; }