multisend cleanup
This commit is contained in:
parent
1fb88271e5
commit
4efda89358
|
@ -16,6 +16,12 @@ import io.micrometer.core.instrument.Counter;
|
|||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Tag;
|
||||
import io.micrometer.core.instrument.Tags;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.media.Content;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
|
||||
import java.security.MessageDigest;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
|
@ -37,7 +43,11 @@ import java.util.concurrent.CompletableFuture;
|
|||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nonnull;
|
||||
import javax.annotation.Nullable;
|
||||
|
@ -112,6 +122,8 @@ import org.whispersystems.textsecuregcm.util.Pair;
|
|||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
||||
import org.whispersystems.websocket.Stories;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||
|
@ -119,6 +131,12 @@ import reactor.core.scheduler.Scheduler;
|
|||
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
|
||||
public class MessageController {
|
||||
|
||||
private record MultiRecipientDeliveryData(
|
||||
ServiceIdentifier serviceIdentifier,
|
||||
Account account,
|
||||
Map<Byte, Recipient> perDeviceData) {
|
||||
}
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
|
||||
|
||||
private final RateLimiters rateLimiters;
|
||||
|
@ -141,6 +159,7 @@ public class MessageController {
|
|||
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
|
||||
|
||||
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
|
||||
private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage");
|
||||
|
||||
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
|
||||
private static final String SENDER_TYPE_TAG_NAME = "senderType";
|
||||
|
@ -347,26 +366,25 @@ public class MessageController {
|
|||
|
||||
|
||||
/**
|
||||
* Build mapping of accounts to devices/registration IDs.
|
||||
* Build mapping of service IDs to resolved accounts and device/registration IDs
|
||||
*/
|
||||
private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap(
|
||||
MultiRecipientMessage multiRecipientMessage,
|
||||
Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {
|
||||
|
||||
return Arrays.stream(multiRecipientMessage.recipients())
|
||||
// for normal messages, all recipients UUIDs are in the map,
|
||||
// but story messages might specify inactive UUIDs, which we
|
||||
// have previously filtered
|
||||
.filter(r -> accountsByServiceIdentifier.containsKey(r.uuid()))
|
||||
.collect(Collectors.toMap(
|
||||
recipient -> accountsByServiceIdentifier.get(recipient.uuid()),
|
||||
recipient -> new HashSet<>(
|
||||
Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))),
|
||||
(a, b) -> {
|
||||
a.addAll(b);
|
||||
return a;
|
||||
}
|
||||
));
|
||||
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
|
||||
MultiRecipientMessage multiRecipientMessage, boolean isStory) {
|
||||
return Flux.fromArray(multiRecipientMessage.recipients())
|
||||
.groupBy(Recipient::uuid, multiRecipientMessage.recipients().length)
|
||||
.flatMap(
|
||||
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key()))
|
||||
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
|
||||
.flatMap(
|
||||
account ->
|
||||
gf.collectMap(Recipient::deviceId)
|
||||
.map(perRecipientData ->
|
||||
new MultiRecipientDeliveryData(
|
||||
gf.key(),
|
||||
account,
|
||||
perRecipientData))))
|
||||
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
|
||||
.block();
|
||||
}
|
||||
|
||||
@Timed
|
||||
|
@ -375,79 +393,87 @@ public class MessageController {
|
|||
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
|
||||
@Produces(MediaType.APPLICATION_JSON)
|
||||
@FilterSpam
|
||||
@Operation(
|
||||
summary = "Send multi-recipient sealed-sender message",
|
||||
description = """
|
||||
Deliver a common-payload message to multiple recipients.
|
||||
An unidentifed-access key for all recipients must be provided, unless the message is a story.
|
||||
""")
|
||||
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
|
||||
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
|
||||
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect")
|
||||
@ApiResponse(
|
||||
responseCode="404",
|
||||
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
|
||||
@ApiResponse(
|
||||
responseCode = "409", description = "Incorrect set of devices supplied for some recipients",
|
||||
content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class)))
|
||||
@ApiResponse(
|
||||
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
|
||||
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))
|
||||
|
||||
public Response sendMultiRecipientMessage(
|
||||
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message")
|
||||
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
|
||||
|
||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
|
||||
|
||||
@Parameter(description="If true, deliver the message only to recipients that are online when it is sent")
|
||||
@QueryParam("online") boolean online,
|
||||
|
||||
@Parameter(description="The sender's timestamp for the envelope")
|
||||
@QueryParam("ts") long timestamp,
|
||||
|
||||
@Parameter(description="If true, this message should cause push notifications to be sent to recipients")
|
||||
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
|
||||
|
||||
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
|
||||
@QueryParam("story") boolean isStory,
|
||||
@Parameter(description="The sealed-sender multi-recipient message payload")
|
||||
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
|
||||
|
||||
final Map<ServiceIdentifier, Account> accountsByServiceIdentifier = new HashMap<>();
|
||||
|
||||
for (final Recipient recipient : multiRecipientMessage.recipients()) {
|
||||
if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) {
|
||||
final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid());
|
||||
|
||||
if (maybeAccount.isPresent()) {
|
||||
accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get());
|
||||
} else {
|
||||
if (!isStory) {
|
||||
throw new NotFoundException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
|
||||
|
||||
// Stories will be checked by the client; we bypass access checks here for stories.
|
||||
if (!isStory) {
|
||||
checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
|
||||
checkAccessKeys(accessKeys, recipients.values());
|
||||
}
|
||||
|
||||
final Map<Account, Set<Pair<Byte, Integer>>> accountToDeviceIdAndRegistrationIdMap =
|
||||
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);
|
||||
|
||||
// We might filter out all the recipients of a story (if none have enabled stories).
|
||||
// We might filter out all the recipients of a story (if none exist).
|
||||
// In this case there is no error so we should just return 200 now.
|
||||
if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) {
|
||||
return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build();
|
||||
if (isStory) {
|
||||
if (recipients.isEmpty()) {
|
||||
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
|
||||
}
|
||||
for (MultiRecipientDeliveryData recipient : recipients.values()) {
|
||||
rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid());
|
||||
}
|
||||
}
|
||||
|
||||
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
||||
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
||||
recipients.values().forEach(recipient -> {
|
||||
final Account account = recipient.account();
|
||||
|
||||
for (Map.Entry<ServiceIdentifier, Account> entry : accountsByServiceIdentifier.entrySet()) {
|
||||
final ServiceIdentifier serviceIdentifier = entry.getKey();
|
||||
final Account account = entry.getValue();
|
||||
|
||||
if (isStory) {
|
||||
rateLimiters.getStoriesLimiter().validate(account.getUuid());
|
||||
}
|
||||
|
||||
Set<Byte> deviceIds = accountToDeviceIdAndRegistrationIdMap
|
||||
.getOrDefault(account, Collections.emptySet())
|
||||
.stream()
|
||||
.map(Pair::first)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
try {
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
|
||||
|
||||
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
|
||||
// identity
|
||||
DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
|
||||
false);
|
||||
} catch (MismatchedDevicesException e) {
|
||||
accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier,
|
||||
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
||||
} catch (StaleDevicesException e) {
|
||||
accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices())));
|
||||
}
|
||||
}
|
||||
try {
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet());
|
||||
|
||||
DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
recipient.perDeviceData().values(),
|
||||
Recipient::deviceId,
|
||||
Recipient::registrationId,
|
||||
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
|
||||
} catch (MismatchedDevicesException e) {
|
||||
accountMismatchedDevices.add(
|
||||
new AccountMismatchedDevices(
|
||||
recipient.serviceIdentifier(),
|
||||
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
||||
} catch (StaleDevicesException e) {
|
||||
accountStaleDevices.add(
|
||||
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
|
||||
}
|
||||
});
|
||||
if (!accountMismatchedDevices.isEmpty()) {
|
||||
return Response
|
||||
.status(409)
|
||||
|
@ -472,25 +498,28 @@ public class MessageController {
|
|||
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));
|
||||
|
||||
CompletableFuture.allOf(
|
||||
Arrays.stream(multiRecipientMessage.recipients())
|
||||
// If we're sending a story, some recipients might not map to existing accounts
|
||||
.filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
|
||||
.map(
|
||||
recipient -> CompletableFuture.runAsync(
|
||||
() -> {
|
||||
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
|
||||
|
||||
// we asserted this must exist in validateCompleteDeviceList
|
||||
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
|
||||
sentMessageCounter.increment();
|
||||
try {
|
||||
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
|
||||
recipient, multiRecipientMessage.commonPayload());
|
||||
} catch (NoSuchUserException e) {
|
||||
uuids404.add(recipient.uuid());
|
||||
}
|
||||
},
|
||||
multiRecipientMessageExecutor))
|
||||
recipients.values().stream()
|
||||
.flatMap(recipientData ->
|
||||
recipientData.perDeviceData().values().stream().map(
|
||||
recipient -> CompletableFuture.runAsync(
|
||||
() -> {
|
||||
final Account destinationAccount = recipientData.account();
|
||||
// we asserted this must exist in validateCompleteDeviceList
|
||||
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
|
||||
try {
|
||||
sentMessageCounter.increment();
|
||||
sendCommonPayloadMessage(
|
||||
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
|
||||
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload());
|
||||
} catch (NoSuchUserException e) {
|
||||
// this should never happen, because we already asserted the device is present and enabled
|
||||
Metrics.counter(
|
||||
UNEXPECTED_MISSING_USER_COUNTER_NAME,
|
||||
Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment();
|
||||
uuids404.add(recipientData.serviceIdentifier());
|
||||
}
|
||||
},
|
||||
multiRecipientMessageExecutor)))
|
||||
.toArray(CompletableFuture[]::new))
|
||||
.get();
|
||||
} catch (InterruptedException e) {
|
||||
|
@ -506,41 +535,27 @@ public class MessageController {
|
|||
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
|
||||
}
|
||||
|
||||
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<Account> destinationAccounts) {
|
||||
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MultiRecipientDeliveryData> destinations) {
|
||||
// We should not have null access keys when checking access; bail out early.
|
||||
if (accessKeys == null) {
|
||||
throw new WebApplicationException(Status.UNAUTHORIZED);
|
||||
}
|
||||
AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
|
||||
byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
|
||||
final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
|
||||
byte[] combinedUnknownAccessKeys = destinationAccounts.stream()
|
||||
.map(account -> {
|
||||
if (account.isUnrestrictedUnidentifiedAccess()) {
|
||||
return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY;
|
||||
} else {
|
||||
return account.getUnidentifiedAccessKey();
|
||||
}
|
||||
})
|
||||
.map(accessKey -> {
|
||||
if (accessKey.isEmpty()) {
|
||||
throwUnauthorized.set(true);
|
||||
return empty;
|
||||
}
|
||||
return accessKey.get();
|
||||
})
|
||||
.reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> {
|
||||
if (bytes.length != bytes2.length) {
|
||||
throwUnauthorized.set(true);
|
||||
return bytes;
|
||||
}
|
||||
for (int i = 0; i < bytes.length; i++) {
|
||||
bytes[i] ^= bytes2[i];
|
||||
}
|
||||
return bytes;
|
||||
});
|
||||
if (throwUnauthorized.get()
|
||||
|| !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
|
||||
|
||||
final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH;
|
||||
final byte[] combinedUnidentifiedAccessKeys = destinations.stream()
|
||||
.map(MultiRecipientDeliveryData::account)
|
||||
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
|
||||
.map(account ->
|
||||
account.getUnidentifiedAccessKey()
|
||||
.filter(b -> b.length == keyLength)
|
||||
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
|
||||
.reduce(new byte[keyLength],
|
||||
(a, b) -> {
|
||||
final byte[] xor = new byte[keyLength];
|
||||
IntStream.range(0, keyLength).forEach(i -> xor[i] = (byte) (a[i] ^ b[i]));
|
||||
return xor;
|
||||
});
|
||||
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) {
|
||||
throw new WebApplicationException(Status.UNAUTHORIZED);
|
||||
}
|
||||
}
|
||||
|
@ -720,6 +735,7 @@ public class MessageController {
|
|||
|
||||
private void sendCommonPayloadMessage(Account destinationAccount,
|
||||
Device destinationDevice,
|
||||
ServiceIdentifier serviceIdentifier,
|
||||
long timestamp,
|
||||
boolean online,
|
||||
boolean story,
|
||||
|
@ -743,7 +759,7 @@ public class MessageController {
|
|||
.setContent(ByteString.copyFrom(payload))
|
||||
.setStory(story)
|
||||
.setUrgent(urgent)
|
||||
.setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString());
|
||||
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString());
|
||||
|
||||
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
|
||||
} catch (NotPushRegisteredException e) {
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
|
|||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
import static org.hamcrest.collection.IsEmptyCollection.empty;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
@ -29,6 +30,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
|
||||
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
|
||||
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
@ -42,21 +44,24 @@ import java.io.ByteArrayInputStream;
|
|||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.client.Invocation;
|
||||
|
@ -73,8 +78,11 @@ import org.junit.jupiter.api.Test;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.ArgumentsSources;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
@ -92,8 +100,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
|||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
|
@ -139,6 +145,7 @@ class MessageControllerTest {
|
|||
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
|
||||
private static final byte SINGLE_DEVICE_ID1 = 1;
|
||||
private static final int SINGLE_DEVICE_REG_ID1 = 111;
|
||||
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
|
||||
|
||||
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
|
||||
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
|
||||
|
@ -149,6 +156,11 @@ class MessageControllerTest {
|
|||
private static final int MULTI_DEVICE_REG_ID1 = 222;
|
||||
private static final int MULTI_DEVICE_REG_ID2 = 333;
|
||||
private static final int MULTI_DEVICE_REG_ID3 = 444;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
|
||||
|
||||
private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
|
||||
|
||||
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
|
||||
|
||||
|
@ -192,13 +204,13 @@ class MessageControllerTest {
|
|||
|
||||
|
||||
final List<Device> singleDeviceList = List.of(
|
||||
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
|
||||
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
|
||||
);
|
||||
|
||||
final List<Device> multiDeviceList = List.of(
|
||||
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
|
||||
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
|
||||
);
|
||||
|
||||
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
|
||||
|
@ -211,6 +223,8 @@ class MessageControllerTest {
|
|||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
|
||||
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
|
||||
|
||||
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
|
||||
mock(DynamicInboundMessageByteLimitConfiguration.class);
|
||||
|
@ -942,25 +956,21 @@ class MessageControllerTest {
|
|||
);
|
||||
}
|
||||
|
||||
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
|
||||
long x = deviceId;
|
||||
// write the device-id in the 7-bit varint format we use, least significant bytes first.
|
||||
do {
|
||||
long b = x & 0x7f;
|
||||
x = x >>> 7;
|
||||
if (x != 0) b |= 0x80;
|
||||
bb.put((byte)b);
|
||||
} while (x != 0);
|
||||
private record Recipient(ServiceIdentifier uuid,
|
||||
byte deviceId,
|
||||
int registrationId,
|
||||
byte[] perRecipientKeyMaterial) {
|
||||
}
|
||||
|
||||
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
|
||||
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
|
||||
final boolean useExplicitIdentifier) {
|
||||
if (useExplicitIdentifier) {
|
||||
bb.put(r.uuid().toFixedWidthByteArray());
|
||||
} else {
|
||||
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
|
||||
}
|
||||
|
||||
writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
|
||||
bb.put(r.deviceId()); // device id (1 byte)
|
||||
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
|
||||
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
|
||||
}
|
||||
|
@ -973,8 +983,15 @@ class MessageControllerTest {
|
|||
// first write the header
|
||||
bb.put(explicitIdentifiers
|
||||
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
|
||||
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
|
||||
bb.put((byte)recipients.size()); // count varint
|
||||
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
|
||||
|
||||
// count varint
|
||||
int nRecip = recipients.size();
|
||||
while (nRecip > 127) {
|
||||
bb.put((byte) (nRecip & 0x7F | 0x80));
|
||||
nRecip = nRecip >> 7;
|
||||
}
|
||||
bb.put((byte)(nRecip & 0x7F));
|
||||
|
||||
Iterator<Recipient> it = recipients.iterator();
|
||||
while (it.hasNext()) {
|
||||
|
@ -988,23 +1005,65 @@ class MessageControllerTest {
|
|||
return new ByteArrayInputStream(buffer, 0, bb.position());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
|
||||
@Test
|
||||
void testManyRecipientMessage() throws Exception {
|
||||
final int nRecipients = 999;
|
||||
final int devicesPerRecipient = 5;
|
||||
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
|
||||
final List<Recipient> recipients = new ArrayList<>();
|
||||
|
||||
final List<Recipient> recipients;
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
|
||||
);
|
||||
} else {
|
||||
recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
|
||||
for (int i = 0; i < nRecipients; i++) {
|
||||
final List<Device> devices =
|
||||
IntStream.range(1, devicesPerRecipient + 1)
|
||||
.mapToObj(
|
||||
d -> generateTestDevice(
|
||||
(byte) d, 100 + d, 10 * d, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(),
|
||||
System.currentTimeMillis()))
|
||||
.collect(Collectors.toList());
|
||||
final UUID aci = new UUID(0L, (long) i);
|
||||
final UUID pni = new UUID(1L, (long) i);
|
||||
final String e164 = String.format("+1408555%04d", i);
|
||||
final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES);
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci))).thenReturn(Optional.of(account));
|
||||
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni))).thenReturn(Optional.of(account));
|
||||
devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48])));
|
||||
}
|
||||
|
||||
byte[] buffer = new byte[1048576];
|
||||
InputStream stream = initializeMultiPayload(recipients, buffer, true);
|
||||
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
|
||||
final Response response = resources
|
||||
.getJerseyTest()
|
||||
.target("/v1/messages/multi_recipient")
|
||||
.queryParam("online", true)
|
||||
.queryParam("story", true)
|
||||
.queryParam("urgent", false)
|
||||
.request()
|
||||
.header(HttpHeaders.USER_AGENT, "FIXME")
|
||||
.put(entity);
|
||||
|
||||
assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200)));
|
||||
verify(messageSender, times(nRecipients * devicesPerRecipient)).sendMessage(any(), any(), any(), eq(true));
|
||||
}
|
||||
|
||||
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations
|
||||
private void testMultiRecipientMessage(
|
||||
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
|
||||
boolean authorize,
|
||||
boolean isStory,
|
||||
boolean urgent,
|
||||
boolean explicitIdentifier,
|
||||
int expectedStatus,
|
||||
int expectedMessagesSent) throws Exception {
|
||||
final List<Recipient> recipients = new ArrayList<>();
|
||||
destinations.forEach(
|
||||
(serviceIdentifier, deviceToRegistrationId) ->
|
||||
deviceToRegistrationId.forEach(
|
||||
(deviceId, registrationId) ->
|
||||
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
|
||||
|
||||
// initialize our binary payload and create an input stream
|
||||
byte[] buffer = new byte[2048];
|
||||
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
|
||||
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
|
||||
|
||||
// set up the entity to use in our PUT request
|
||||
|
@ -1023,124 +1082,160 @@ class MessageControllerTest {
|
|||
|
||||
// add access header if needed
|
||||
if (authorize) {
|
||||
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
|
||||
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
|
||||
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
|
||||
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
|
||||
}
|
||||
|
||||
// make the PUT request
|
||||
Response response = bldr.put(entity);
|
||||
|
||||
if (authorize) {
|
||||
ArgumentCaptor<Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class);
|
||||
verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean());
|
||||
assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent());
|
||||
}
|
||||
|
||||
// We have a 2x2x2 grid of possible situations based on:
|
||||
// - recipient enabled stories?
|
||||
// - sender is authorized?
|
||||
// - message is a story?
|
||||
//
|
||||
// (urgent is not included in the grid because it has no effect
|
||||
// on any of the other settings.)
|
||||
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
// This is the case where the recipient has enabled stories.
|
||||
if(isStory) {
|
||||
// We are sending a story, so we ignore access checks and expect this
|
||||
// to go out to both the recipient's devices.
|
||||
checkGoodMultiRecipientResponse(response, 2);
|
||||
} else {
|
||||
// We are not sending a story, so we need to do access checks.
|
||||
if (authorize) {
|
||||
// When authorized we send a message to the recipient's devices.
|
||||
checkGoodMultiRecipientResponse(response, 2);
|
||||
} else {
|
||||
// When forbidden, we return a 401 error.
|
||||
checkBadMultiRecipientResponse(response, 401);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// This is the case where the recipient has not enabled stories.
|
||||
if (isStory) {
|
||||
// We are sending a story, so we ignore access checks.
|
||||
// this recipient has one device.
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
} else {
|
||||
// We are not sending a story so check access.
|
||||
if (authorize) {
|
||||
// If allowed, send a message to the recipient's one device.
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
} else {
|
||||
// If forbidden, return a 401 error.
|
||||
checkBadMultiRecipientResponse(response, 401);
|
||||
}
|
||||
}
|
||||
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
|
||||
verify(messageSender,
|
||||
exactly(expectedMessagesSent))
|
||||
.sendMessage(
|
||||
any(),
|
||||
any(),
|
||||
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
|
||||
eq(true));
|
||||
if (expectedStatus == 200) {
|
||||
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
|
||||
assertThat(smrmr.uuids404(), is(empty()));
|
||||
}
|
||||
}
|
||||
|
||||
// Arguments here are: recipient-UUID, is-authorized?, is-story?
|
||||
private static Stream<Arguments> testMultiRecipientMessage() {
|
||||
return Stream.of(
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
|
||||
);
|
||||
@SafeVarargs
|
||||
private static <K, V> Map<K, V> submap(Map<K, V> map, K... keys) {
|
||||
return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
|
||||
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
|
||||
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
|
||||
return
|
||||
Map.of(
|
||||
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
|
||||
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
Map.of(
|
||||
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
|
||||
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
|
||||
new PniServiceIdentifier(MULTI_DEVICE_PNI),
|
||||
Map.of(
|
||||
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
|
||||
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
|
||||
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
|
||||
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
|
||||
);
|
||||
}
|
||||
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
|
||||
new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
|
||||
private record MultiRecipientMessageTestCase(
|
||||
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
|
||||
boolean authenticated,
|
||||
boolean story,
|
||||
int expectedStatus,
|
||||
int expectedSentMessages) {
|
||||
}
|
||||
|
||||
Response response = resources
|
||||
.getJerseyTest()
|
||||
.target("/v1/messages/multi_recipient")
|
||||
.queryParam("online", true)
|
||||
.queryParam("ts", 1700000000000L)
|
||||
.queryParam("story", true)
|
||||
.queryParam("urgent", false)
|
||||
.request()
|
||||
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
|
||||
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
|
||||
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
|
||||
MultiRecipientMessageProvider.MEDIA_TYPE));
|
||||
@CartesianTest
|
||||
@CartesianTest.MethodFactory("testMultiRecipientMessageNoPni")
|
||||
void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception {
|
||||
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
|
||||
}
|
||||
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
private static ArgumentSets testMultiRecipientMessageNoPni() {
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
|
||||
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
|
||||
submap(
|
||||
targets,
|
||||
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
new AciServiceIdentifier(NONEXISTENT_UUID));
|
||||
|
||||
final boolean auth = true;
|
||||
final boolean unauth = false;
|
||||
final boolean story = true;
|
||||
final boolean notStory = false;
|
||||
|
||||
return ArgumentSets
|
||||
.argumentsForFirstParameter(
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0))
|
||||
.argumentsForNextParameter(false, true) // urgent
|
||||
.argumentsForNextParameter(false, true); // explicitIdentifiers
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
@CartesianTest.MethodFactory("testMultiRecipientMessagePni")
|
||||
void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception {
|
||||
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages());
|
||||
}
|
||||
|
||||
private static ArgumentSets testMultiRecipientMessagePni() {
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
|
||||
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
|
||||
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
|
||||
submap(
|
||||
targets,
|
||||
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
new PniServiceIdentifier(NONEXISTENT_UUID));
|
||||
|
||||
final boolean auth = true;
|
||||
final boolean unauth = false;
|
||||
final boolean story = true;
|
||||
final boolean notStory = false;
|
||||
|
||||
return ArgumentSets
|
||||
.argumentsForFirstParameter(
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0))
|
||||
.argumentsForNextParameter(false, true); // urgent
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -1148,7 +1243,7 @@ class MessageControllerTest {
|
|||
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
|
||||
|
||||
Response response = resources
|
||||
|
@ -1346,12 +1441,12 @@ class MessageControllerTest {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
|
||||
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
|
||||
throws NotPushRegisteredException, InterruptedException {
|
||||
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48]));
|
||||
|
||||
// initialize our binary payload and create an input stream
|
||||
byte[] buffer = new byte[2048];
|
||||
|
@ -1384,8 +1479,8 @@ class MessageControllerTest {
|
|||
|
||||
private static Stream<Arguments> sendMultiRecipientMessage404() {
|
||||
return Stream.of(
|
||||
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
|
||||
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
|
||||
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
|
||||
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
|
||||
}
|
||||
|
||||
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
|
||||
|
@ -1393,14 +1488,6 @@ class MessageControllerTest {
|
|||
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
|
||||
}
|
||||
|
||||
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
|
||||
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
|
||||
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
|
||||
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
|
||||
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
|
||||
assert (smrmr.uuids404().isEmpty());
|
||||
}
|
||||
|
||||
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
|
||||
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
|
||||
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
|
||||
|
@ -1433,64 +1520,4 @@ class MessageControllerTest {
|
|||
return builder.build();
|
||||
}
|
||||
|
||||
private static Recipient genRecipient(Random rng) {
|
||||
UUID u1 = UUID.randomUUID(); // non-null
|
||||
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
|
||||
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
|
||||
byte[] perKeyBytes = new byte[48]; // size=48, non-null
|
||||
rng.nextBytes(perKeyBytes);
|
||||
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
|
||||
}
|
||||
|
||||
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
|
||||
ByteBuffer bb = ByteBuffer.wrap(bytes);
|
||||
writePayloadDeviceId(bb, expected);
|
||||
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
|
||||
long got = MultiRecipientMessageProvider.readVarint(stream);
|
||||
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVarintPayload() throws Exception {
|
||||
Random rng = new Random();
|
||||
byte[] bytes = new byte[12];
|
||||
|
||||
// some static test cases
|
||||
for (byte i = 1; i <= 10; i++) {
|
||||
roundTripVarint(i, bytes);
|
||||
}
|
||||
roundTripVarint(Byte.MAX_VALUE, bytes);
|
||||
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
// we need to ensure positive device IDs
|
||||
byte start = (byte) rng.nextInt(128);
|
||||
if (start == 0L) {
|
||||
start = 1;
|
||||
}
|
||||
|
||||
// run the test for this case
|
||||
roundTripVarint(start, bytes);
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
|
||||
Random rng = new java.util.Random();
|
||||
List<Recipient> expected = new LinkedList<>();
|
||||
for(int i = 0; i < 100; i++) {
|
||||
expected.add(genRecipient(rng));
|
||||
}
|
||||
|
||||
byte[] buffer = new byte[100 + expected.size() * 100];
|
||||
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
|
||||
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
|
||||
// the provider ignores the headers, java reflection, etc. so we don't use those here.
|
||||
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
|
||||
List<Recipient> got = Arrays.asList(res.recipients());
|
||||
|
||||
assertEquals(expected, got);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -10,10 +10,7 @@ import static org.mockito.Mockito.doNothing;
|
|||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
|
||||
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
|
||||
import static org.mockito.internal.invocation.InvocationMarker.markVerified;
|
||||
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
|
||||
import static org.mockito.internal.invocation.InvocationsFinder.findInvocations;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
@ -169,10 +166,17 @@ public final class MockUtils {
|
|||
* this method
|
||||
*/
|
||||
public static VerificationMode exactly() {
|
||||
return exactly(1);
|
||||
}
|
||||
|
||||
/**
|
||||
* a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that
|
||||
* there are exactly N invocations of this method, and all of them match the given specification
|
||||
*/
|
||||
public static VerificationMode exactly(int wantedCount) {
|
||||
return data -> {
|
||||
MatchableInvocation target = data.getTarget();
|
||||
final List<Invocation> allInvocations = data.getAllInvocations();
|
||||
List<Invocation> chunk = findInvocations(allInvocations, target);
|
||||
List<Invocation> otherInvocations = allInvocations.stream()
|
||||
.filter(target::hasSameMethod)
|
||||
.filter(Predicate.not(target::matches))
|
||||
|
@ -182,10 +186,7 @@ public final class MockUtils {
|
|||
Invocation unverified = findFirstUnverified(otherInvocations);
|
||||
throw noMoreInteractionsWanted(unverified, (List) allInvocations);
|
||||
}
|
||||
if (chunk.isEmpty()) {
|
||||
throw wantedButNotInvoked(target);
|
||||
}
|
||||
markVerified(chunk.get(0), target);
|
||||
Mockito.times(wantedCount).verify(data);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue