multisend cleanup

This commit is contained in:
Jonathan Klabunde Tomer 2023-11-30 15:50:36 -08:00 committed by GitHub
parent 22e6584402
commit c03249b411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 344 additions and 343 deletions

View File

@ -16,6 +16,12 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest;
import java.time.Duration;
import java.util.ArrayList;
@ -37,6 +43,8 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
@ -87,6 +95,7 @@ import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageRespon
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -111,6 +120,8 @@ import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -118,6 +129,12 @@ import reactor.core.scheduler.Scheduler;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
public class MessageController {
private record MessageRecipient(
ServiceIdentifier serviceIdentifier,
Account account,
Map<Byte, Recipient> perDeviceData) {
}
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters;
@ -138,9 +155,9 @@ public class MessageController {
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
private static final String RATE_LIMITED_STORIES_COUNTER_NAME = name(MessageController.class, "rateLimitedStory");
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String SENDER_TYPE_TAG_NAME = "senderType";
@ -343,26 +360,25 @@ public class MessageController {
/**
* Build mapping of accounts to devices/registration IDs.
* Build mapping of service IDs to resolved accounts and device/registration IDs
*/
private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage,
Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {
return Arrays.stream(multiRecipientMessage.recipients())
// for normal messages, all recipients UUIDs are in the map,
// but story messages might specify inactive UUIDs, which we
// have previously filtered
.filter(r -> accountsByServiceIdentifier.containsKey(r.uuid()))
.collect(Collectors.toMap(
recipient -> accountsByServiceIdentifier.get(recipient.uuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
private Map<ServiceIdentifier, MessageRecipient> buildRecipientMap(
MultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromArray(multiRecipientMessage.recipients())
.groupBy(Recipient::uuid)
.flatMap(
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key()))
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.flatMap(
account ->
gf.collectMap(Recipient::deviceId)
.map(perRecipientData ->
new MessageRecipient(
gf.key(),
account,
perRecipientData))))
.collectMap(MessageRecipient::serviceIdentifier)
.block();
}
@Timed
@ -371,79 +387,87 @@ public class MessageController {
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
@Produces(MediaType.APPLICATION_JSON)
@FilterSpam
@Operation(
summary = "Send multi-recipient sealed-sender message",
description = """
Deliver a common-payload message to multiple recipients.
An unidentifed-access key for all recipients must be provided, unless the message is a story.
""")
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect")
@ApiResponse(
responseCode="404",
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for some recipients",
content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))
public Response sendMultiRecipientMessage(
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message")
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@Parameter(description="If true, deliver the message only to recipients that are online when it is sent")
@QueryParam("online") boolean online,
@Parameter(description="The sender's timestamp for the envelope")
@QueryParam("ts") long timestamp,
@Parameter(description="If true, this message should cause push notifications to be sent to recipients")
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
@QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload")
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
final Map<ServiceIdentifier, Account> accountsByServiceIdentifier = new HashMap<>();
for (final Recipient recipient : multiRecipientMessage.recipients()) {
if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) {
final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid());
if (maybeAccount.isPresent()) {
accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get());
} else {
if (!isStory) {
throw new NotFoundException();
}
}
}
}
final Map<ServiceIdentifier, MessageRecipient> recipients = buildRecipientMap(multiRecipientMessage, isStory);
// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
checkAccessKeys(accessKeys, recipients.values());
}
final Map<Account, Set<Pair<Byte, Integer>>> accountToDeviceIdAndRegistrationIdMap =
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);
// We might filter out all the recipients of a story (if none have enabled stories).
// We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now.
if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build();
if (isStory) {
if (recipients.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
}
for (MessageRecipient recipient : recipients.values()) {
rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid());
}
}
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account();
for (Map.Entry<ServiceIdentifier, Account> entry : accountsByServiceIdentifier.entrySet()) {
final ServiceIdentifier serviceIdentifier = entry.getKey();
final Account account = entry.getValue();
if (isStory) {
rateLimiters.getStoriesLimiter().validate(account.getUuid());
}
Set<Byte> deviceIds = accountToDeviceIdAndRegistrationIdMap
.getOrDefault(account, Collections.emptySet())
.stream()
.map(Pair::first)
.collect(Collectors.toSet());
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity
DestinationDeviceValidator.validateRegistrationIds(
account,
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
false);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices())));
}
}
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.perDeviceData().values(),
Recipient::deviceId,
Recipient::registrationId,
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
@ -468,25 +492,28 @@ public class MessageController {
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));
CompletableFuture.allOf(
Arrays.stream(multiRecipientMessage.recipients())
// If we're sending a story, some recipients might not map to existing accounts
.filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
.map(
recipient -> CompletableFuture.runAsync(
() -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
},
multiRecipientMessageExecutor))
recipients.values().stream()
.flatMap(recipientData ->
recipientData.perDeviceData().values().stream().map(
recipient -> CompletableFuture.runAsync(
() -> {
final Account destinationAccount = recipientData.account();
// we asserted this must exist in validateCompleteDeviceList
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
try {
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
// this should never happen, because we already asserted the device is present and enabled
Metrics.counter(
UNEXPECTED_MISSING_USER_COUNTER_NAME,
Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment();
uuids404.add(recipientData.serviceIdentifier());
}
},
multiRecipientMessageExecutor)))
.toArray(CompletableFuture[]::new))
.get();
} catch (InterruptedException e) {
@ -502,43 +529,31 @@ public class MessageController {
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
}
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<Account> destinationAccounts) {
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MessageRecipient> destinations) {
// We should not have null access keys when checking access; bail out early.
if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
byte[] combinedUnknownAccessKeys = destinationAccounts.stream()
.map(account -> {
if (account.isUnrestrictedUnidentifiedAccess()) {
return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY;
} else {
return account.getUnidentifiedAccessKey();
}
})
.map(accessKey -> {
if (accessKey.isEmpty()) {
throwUnauthorized.set(true);
return empty;
}
return accessKey.get();
})
.reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> {
if (bytes.length != bytes2.length) {
throwUnauthorized.set(true);
return bytes;
}
for (int i = 0; i < bytes.length; i++) {
bytes[i] ^= bytes2[i];
}
return bytes;
});
if (throwUnauthorized.get()
|| !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
destinations.stream()
.map(MessageRecipient::account)
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account -> account.getUnidentifiedAccessKey().orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(
(bytes, bytes2) -> {
if (bytes.length != bytes2.length) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
for (int i = 0; i < bytes.length; i++) {
bytes[i] ^= bytes2[i];
}
return bytes;
})
.ifPresent(
combinedUnidentifiedAccessKeys -> {
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
});
}
@Timed
@ -716,6 +731,7 @@ public class MessageController {
private void sendCommonPayloadMessage(Account destinationAccount,
Device destinationDevice,
ServiceIdentifier serviceIdentifier,
long timestamp,
boolean online,
boolean story,
@ -739,7 +755,7 @@ public class MessageController {
.setContent(ByteString.copyFrom(payload))
.setStory(story)
.setUrgent(urgent)
.setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString());
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString());
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} catch (NotPushRegisteredException e) {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.collection.IsEmptyCollection.empty;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -29,6 +30,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
@ -42,21 +44,23 @@ import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
@ -73,8 +77,11 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsSources;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -92,8 +99,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
@ -139,6 +144,7 @@ class MessageControllerTest {
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
@ -149,6 +155,11 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222;
private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333;
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@ -192,13 +203,13 @@ class MessageControllerTest {
final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
);
final List<Device> multiDeviceList = List.of(
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
);
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
@ -211,6 +222,8 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
mock(DynamicInboundMessageByteLimitConfiguration.class);
@ -922,25 +935,21 @@ class MessageControllerTest {
);
}
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
long b = x & 0x7f;
x = x >>> 7;
if (x != 0) b |= 0x80;
bb.put((byte)b);
} while (x != 0);
private record Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
}
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(r.uuid().toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
}
writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
bb.put(r.deviceId()); // device id (1 byte)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
@ -953,8 +962,8 @@ class MessageControllerTest {
// first write the header
bb.put(explicitIdentifiers
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
@ -968,23 +977,24 @@ class MessageControllerTest {
return new ByteArrayInputStream(buffer, 0, bb.position());
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
final List<Recipient> recipients;
if (recipientUUID == MULTI_DEVICE_UUID) {
recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
);
} else {
recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
}
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations
private void testMultiRecipientMessage(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
boolean authorize,
boolean isStory,
boolean urgent,
boolean explicitIdentifier,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
destinations.forEach(
(serviceIdentifier, deviceToRegistrationId) ->
deviceToRegistrationId.forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
// set up the entity to use in our PUT request
@ -1003,124 +1013,160 @@ class MessageControllerTest {
// add access header if needed
if (authorize) {
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
if (authorize) {
ArgumentCaptor<Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean());
assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent());
}
// We have a 2x2x2 grid of possible situations based on:
// - recipient enabled stories?
// - sender is authorized?
// - message is a story?
//
// (urgent is not included in the grid because it has no effect
// on any of the other settings.)
if (recipientUUID == MULTI_DEVICE_UUID) {
// This is the case where the recipient has enabled stories.
if(isStory) {
// We are sending a story, so we ignore access checks and expect this
// to go out to both the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// We are not sending a story, so we need to do access checks.
if (authorize) {
// When authorized we send a message to the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// When forbidden, we return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} else {
// This is the case where the recipient has not enabled stories.
if (isStory) {
// We are sending a story, so we ignore access checks.
// this recipient has one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// We are not sending a story so check access.
if (authorize) {
// If allowed, send a message to the recipient's one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// If forbidden, return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
exactly(expectedMessagesSent))
.sendMessage(
any(),
any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
anyBoolean());
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
}
// Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
);
@SafeVarargs
private static <K, V> Map<K, V> submap(Map<K, V> map, K... keys) {
return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get));
}
@Test
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
return
Map.of(
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
);
}
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
new byte[48]),
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
private record MultiRecipientMessageTestCase(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
boolean authenticated,
boolean story,
int expectedStatus,
int expectedSentMessages) {
}
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1700000000000L)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
MultiRecipientMessageProvider.MEDIA_TYPE));
@CartesianTest
@CartesianTest.MethodFactory("testMultiRecipientMessageNoPni")
void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
}
checkGoodMultiRecipientResponse(response, 1);
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap(
targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new AciServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true) // urgent
.argumentsForNextParameter(false, true); // explicitIdentifiers
}
@CartesianTest
@CartesianTest.MethodFactory("testMultiRecipientMessagePni")
void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages());
}
private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap(
targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new PniServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true); // urgent
}
@ParameterizedTest
@ -1128,7 +1174,7 @@ class MessageControllerTest {
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources
@ -1326,12 +1372,12 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
throws NotPushRegisteredException, InterruptedException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
@ -1364,8 +1410,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
@ -1373,14 +1419,6 @@ class MessageControllerTest {
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
}
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.uuids404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
@ -1413,64 +1451,4 @@ class MessageControllerTest {
return builder.build();
}
private static Recipient genRecipient(Random rng) {
UUID u1 = UUID.randomUUID(); // non-null
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
ByteBuffer bb = ByteBuffer.wrap(bytes);
writePayloadDeviceId(bb, expected);
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
long got = MultiRecipientMessageProvider.readVarint(stream);
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
}
@Test
void testVarintPayload() throws Exception {
Random rng = new Random();
byte[] bytes = new byte[12];
// some static test cases
for (byte i = 1; i <= 10; i++) {
roundTripVarint(i, bytes);
}
roundTripVarint(Byte.MAX_VALUE, bytes);
for (int i = 0; i < 1000; i++) {
// we need to ensure positive device IDs
byte start = (byte) rng.nextInt(128);
if (start == 0L) {
start = 1;
}
// run the test for this case
roundTripVarint(start, bytes);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
Random rng = new java.util.Random();
List<Recipient> expected = new LinkedList<>();
for(int i = 0; i < 100; i++) {
expected.add(genRecipient(rng));
}
byte[] buffer = new byte[100 + expected.size() * 100];
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
// the provider ignores the headers, java reflection, etc. so we don't use those here.
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
List<Recipient> got = Arrays.asList(res.recipients());
assertEquals(expected, got);
}
}

View File

@ -10,6 +10,8 @@ import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
import static org.mockito.internal.exceptions.Reporter.tooFewActualInvocations;
import static org.mockito.internal.exceptions.Reporter.tooManyActualInvocations;
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
import static org.mockito.internal.invocation.InvocationMarker.markVerified;
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
@ -26,6 +28,7 @@ import org.mockito.Mockito;
import org.mockito.invocation.Invocation;
import org.mockito.invocation.MatchableInvocation;
import org.mockito.verification.VerificationMode;
import org.mockito.internal.verification.Times;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
@ -171,10 +174,17 @@ public final class MockUtils {
* this method
*/
public static VerificationMode exactly() {
return exactly(1);
}
/**
* a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that
* there are exactly N invocations of this method, and all of them match the given specification
*/
public static VerificationMode exactly(int wantedCount) {
return data -> {
MatchableInvocation target = data.getTarget();
final List<Invocation> allInvocations = data.getAllInvocations();
List<Invocation> chunk = findInvocations(allInvocations, target);
List<Invocation> otherInvocations = allInvocations.stream()
.filter(target::hasSameMethod)
.filter(Predicate.not(target::matches))
@ -184,10 +194,7 @@ public final class MockUtils {
Invocation unverified = findFirstUnverified(otherInvocations);
throw noMoreInteractionsWanted(unverified, (List) allInvocations);
}
if (chunk.isEmpty()) {
throw wantedButNotInvoked(target);
}
markVerified(chunk.get(0), target);
Mockito.times(wantedCount).verify(data);
};
}