Revert "multisend cleanup"

This reverts commit c03249b411.
This commit is contained in:
Jonathan Klabunde Tomer 2023-12-01 14:39:31 -08:00
parent c03249b411
commit 20392a567b
3 changed files with 343 additions and 344 deletions

View File

@ -16,12 +16,6 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
@ -43,8 +37,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -95,7 +87,6 @@ import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageRespon
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -120,8 +111,6 @@ import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories; import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -129,12 +118,6 @@ import reactor.core.scheduler.Scheduler;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages") @io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
public class MessageController { public class MessageController {
private record MessageRecipient(
ServiceIdentifier serviceIdentifier,
Account account,
Map<Byte, Recipient> perDeviceData) {
}
private static final Logger logger = LoggerFactory.getLogger(MessageController.class); private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
@ -155,9 +138,9 @@ public class MessageController {
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage"); private static final String RATE_LIMITED_MESSAGE_COUNTER_NAME = name(MessageController.class, "rateLimitedMessage");
private static final String RATE_LIMITED_STORIES_COUNTER_NAME = name(MessageController.class, "rateLimitedStory");
private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType"); private static final String REJECT_INVALID_ENVELOPE_TYPE = name(MessageController.class, "rejectInvalidEnvelopeType");
private static final String UNEXPECTED_MISSING_USER_COUNTER_NAME = name(MessageController.class, "unexpectedMissingDestinationForMultiRecipientMessage");
private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String SENDER_TYPE_TAG_NAME = "senderType"; private static final String SENDER_TYPE_TAG_NAME = "senderType";
@ -360,25 +343,26 @@ public class MessageController {
/** /**
* Build mapping of service IDs to resolved accounts and device/registration IDs * Build mapping of accounts to devices/registration IDs.
*/ */
private Map<ServiceIdentifier, MessageRecipient> buildRecipientMap( private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage, boolean isStory) { MultiRecipientMessage multiRecipientMessage,
return Flux.fromArray(multiRecipientMessage.recipients()) Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {
.groupBy(Recipient::uuid)
.flatMap( return Arrays.stream(multiRecipientMessage.recipients())
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key())) // for normal messages, all recipients UUIDs are in the map,
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) // but story messages might specify inactive UUIDs, which we
.flatMap( // have previously filtered
account -> .filter(r -> accountsByServiceIdentifier.containsKey(r.uuid()))
gf.collectMap(Recipient::deviceId) .collect(Collectors.toMap(
.map(perRecipientData -> recipient -> accountsByServiceIdentifier.get(recipient.uuid()),
new MessageRecipient( recipient -> new HashSet<>(
gf.key(), Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))),
account, (a, b) -> {
perRecipientData)))) a.addAll(b);
.collectMap(MessageRecipient::serviceIdentifier) return a;
.block(); }
));
} }
@Timed @Timed
@ -387,87 +371,79 @@ public class MessageController {
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@FilterSpam @FilterSpam
@Operation(
summary = "Send multi-recipient sealed-sender message",
description = """
Deliver a common-payload message to multiple recipients.
An unidentifed-access key for all recipients must be provided, unless the message is a story.
""")
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect")
@ApiResponse(
responseCode="404",
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for some recipients",
content = @Content(schema = @Schema(implementation = AccountMismatchedDevices[].class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))
public Response sendMultiRecipientMessage( public Response sendMultiRecipientMessage(
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message")
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@Parameter(description="If true, deliver the message only to recipients that are online when it is sent")
@QueryParam("online") boolean online, @QueryParam("online") boolean online,
@Parameter(description="The sender's timestamp for the envelope")
@QueryParam("ts") long timestamp, @QueryParam("ts") long timestamp,
@Parameter(description="If true, this message should cause push notifications to be sent to recipients")
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent, @QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload")
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { @NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
final Map<ServiceIdentifier, MessageRecipient> recipients = buildRecipientMap(multiRecipientMessage, isStory); final Map<ServiceIdentifier, Account> accountsByServiceIdentifier = new HashMap<>();
for (final Recipient recipient : multiRecipientMessage.recipients()) {
if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) {
final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid());
if (maybeAccount.isPresent()) {
accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get());
} else {
if (!isStory) {
throw new NotFoundException();
}
}
}
}
// Stories will be checked by the client; we bypass access checks here for stories. // Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) { if (!isStory) {
checkAccessKeys(accessKeys, recipients.values()); checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
} }
// We might filter out all the recipients of a story (if none exist). final Map<Account, Set<Pair<Byte, Integer>>> accountToDeviceIdAndRegistrationIdMap =
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);
// We might filter out all the recipients of a story (if none have enabled stories).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
if (isStory) { if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) {
if (recipients.isEmpty()) { return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build();
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
}
for (MessageRecipient recipient : recipients.values()) {
rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid());
}
} }
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account();
try { for (Map.Entry<ServiceIdentifier, Account> entry : accountsByServiceIdentifier.entrySet()) {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet()); final ServiceIdentifier serviceIdentifier = entry.getKey();
final Account account = entry.getValue();
if (isStory) {
rateLimiters.getStoriesLimiter().validate(account.getUuid());
}
Set<Byte> deviceIds = accountToDeviceIdAndRegistrationIdMap
.getOrDefault(account, Collections.emptySet())
.stream()
.map(Pair::first)
.collect(Collectors.toSet());
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity
DestinationDeviceValidator.validateRegistrationIds(
account,
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
false);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices())));
}
}
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.perDeviceData().values(),
Recipient::deviceId,
Recipient::registrationId,
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) { if (!accountMismatchedDevices.isEmpty()) {
return Response return Response
.status(409) .status(409)
@ -492,28 +468,25 @@ public class MessageController {
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));
CompletableFuture.allOf( CompletableFuture.allOf(
recipients.values().stream() Arrays.stream(multiRecipientMessage.recipients())
.flatMap(recipientData -> // If we're sending a story, some recipients might not map to existing accounts
recipientData.perDeviceData().values().stream().map( .filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
recipient -> CompletableFuture.runAsync( .map(
() -> { recipient -> CompletableFuture.runAsync(
final Account destinationAccount = recipientData.account(); () -> {
// we asserted this must exist in validateCompleteDeviceList Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
try { // we asserted this must exist in validateCompleteDeviceList
sentMessageCounter.increment(); Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sendCommonPayloadMessage( sentMessageCounter.increment();
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online, try {
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload()); sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
} catch (NoSuchUserException e) { recipient, multiRecipientMessage.commonPayload());
// this should never happen, because we already asserted the device is present and enabled } catch (NoSuchUserException e) {
Metrics.counter( uuids404.add(recipient.uuid());
UNEXPECTED_MISSING_USER_COUNTER_NAME, }
Tags.of("isPrimary", String.valueOf(destinationDevice.isPrimary()))).increment(); },
uuids404.add(recipientData.serviceIdentifier()); multiRecipientMessageExecutor))
}
},
multiRecipientMessageExecutor)))
.toArray(CompletableFuture[]::new)) .toArray(CompletableFuture[]::new))
.get(); .get();
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -529,31 +502,43 @@ public class MessageController {
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
} }
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MessageRecipient> destinations) { private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<Account> destinationAccounts) {
// We should not have null access keys when checking access; bail out early. // We should not have null access keys when checking access; bail out early.
if (accessKeys == null) { if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED); throw new WebApplicationException(Status.UNAUTHORIZED);
} }
destinations.stream() AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
.map(MessageRecipient::account) byte[] empty = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
.map(account -> account.getUnidentifiedAccessKey().orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) byte[] combinedUnknownAccessKeys = destinationAccounts.stream()
.reduce( .map(account -> {
(bytes, bytes2) -> { if (account.isUnrestrictedUnidentifiedAccess()) {
if (bytes.length != bytes2.length) { return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY;
throw new WebApplicationException(Status.UNAUTHORIZED); } else {
} return account.getUnidentifiedAccessKey();
for (int i = 0; i < bytes.length; i++) { }
bytes[i] ^= bytes2[i]; })
} .map(accessKey -> {
return bytes; if (accessKey.isEmpty()) {
}) throwUnauthorized.set(true);
.ifPresent( return empty;
combinedUnidentifiedAccessKeys -> { }
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) { return accessKey.get();
throw new WebApplicationException(Status.UNAUTHORIZED); })
} .reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], (bytes, bytes2) -> {
}); if (bytes.length != bytes2.length) {
throwUnauthorized.set(true);
return bytes;
}
for (int i = 0; i < bytes.length; i++) {
bytes[i] ^= bytes2[i];
}
return bytes;
});
if (throwUnauthorized.get()
|| !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
} }
@Timed @Timed
@ -731,7 +716,6 @@ public class MessageController {
private void sendCommonPayloadMessage(Account destinationAccount, private void sendCommonPayloadMessage(Account destinationAccount,
Device destinationDevice, Device destinationDevice,
ServiceIdentifier serviceIdentifier,
long timestamp, long timestamp,
boolean online, boolean online,
boolean story, boolean story,
@ -755,7 +739,7 @@ public class MessageController {
.setContent(ByteString.copyFrom(payload)) .setContent(ByteString.copyFrom(payload))
.setStory(story) .setStory(story)
.setUrgent(urgent) .setUrgent(urgent)
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString()); .setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString());
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} catch (NotPushRegisteredException e) { } catch (NotPushRegisteredException e) {

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.collection.IsEmptyCollection.empty;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -30,7 +29,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
@ -44,23 +42,21 @@ import java.io.ByteArrayInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation; import javax.ws.rs.client.Invocation;
@ -77,11 +73,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsSources;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -99,6 +92,8 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
@ -144,7 +139,6 @@ class MessageControllerTest {
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111"); private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final byte SINGLE_DEVICE_ID1 = 1; private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222"); private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
@ -155,11 +149,6 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_REG_ID1 = 222; private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333; private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444; private static final int MULTI_DEVICE_REG_ID3 = 444;
private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222;
private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333;
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@ -203,13 +192,13 @@ class MessageControllerTest {
final List<Device> singleDeviceList = List.of( final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
); );
final List<Device> multiDeviceList = List.of( final List<Device> multiDeviceList = List.of(
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
); );
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
@ -222,8 +211,6 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
mock(DynamicInboundMessageByteLimitConfiguration.class); mock(DynamicInboundMessageByteLimitConfiguration.class);
@ -935,21 +922,25 @@ class MessageControllerTest {
); );
} }
private record Recipient(ServiceIdentifier uuid, private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
byte deviceId, long x = deviceId;
int registrationId, // write the device-id in the 7-bit varint format we use, least significant bytes first.
byte[] perRecipientKeyMaterial) { do {
long b = x & 0x7f;
x = x >>> 7;
if (x != 0) b |= 0x80;
bb.put((byte)b);
} while (x != 0);
} }
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) { if (useExplicitIdentifier) {
bb.put(r.uuid().toFixedWidthByteArray()); bb.put(r.uuid().toFixedWidthByteArray());
} else { } else {
bb.put(UUIDUtil.toBytes(r.uuid().uuid())); bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
} }
bb.put(r.deviceId()); // device id (1 byte) writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
bb.putShort((short) r.registrationId()); // registration id (2 bytes) bb.putShort((short) r.registrationId()); // registration id (2 bytes)
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
} }
@ -962,8 +953,8 @@ class MessageControllerTest {
// first write the header // first write the header
bb.put(explicitIdentifiers bb.put(explicitIdentifiers
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER ? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint bb.put((byte)recipients.size()); // count varint
Iterator<Recipient> it = recipients.iterator(); Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) { while (it.hasNext()) {
@ -977,24 +968,23 @@ class MessageControllerTest {
return new ByteArrayInputStream(buffer, 0, bb.position()); return new ByteArrayInputStream(buffer, 0, bb.position());
} }
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations @ParameterizedTest
private void testMultiRecipientMessage( @MethodSource
Map<ServiceIdentifier, Map<Byte, Integer>> destinations, void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
boolean authorize,
boolean isStory, final List<Recipient> recipients;
boolean urgent, if (recipientUUID == MULTI_DEVICE_UUID) {
boolean explicitIdentifier, recipients = List.of(
int expectedStatus, new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
int expectedMessagesSent) throws Exception { new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
final List<Recipient> recipients = new ArrayList<>(); );
destinations.forEach( } else {
(serviceIdentifier, deviceToRegistrationId) -> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
deviceToRegistrationId.forEach( }
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream // initialize our binary payload and create an input stream
byte[] buffer = new byte[2048]; byte[] buffer = new byte[2048];
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier); InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
// set up the entity to use in our PUT request // set up the entity to use in our PUT request
@ -1013,160 +1003,124 @@ class MessageControllerTest {
// add access header if needed // add access header if needed
if (authorize) { if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
} }
// make the PUT request // make the PUT request
Response response = bldr.put(entity); Response response = bldr.put(entity);
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); if (authorize) {
verify(messageSender, ArgumentCaptor<Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class);
exactly(expectedMessagesSent)) verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean());
.sendMessage( assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent());
any(), }
any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()), // We have a 2x2x2 grid of possible situations based on:
anyBoolean()); // - recipient enabled stories?
if (expectedStatus == 200) { // - sender is authorized?
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); // - message is a story?
assertThat(smrmr.uuids404(), is(empty())); //
// (urgent is not included in the grid because it has no effect
// on any of the other settings.)
if (recipientUUID == MULTI_DEVICE_UUID) {
// This is the case where the recipient has enabled stories.
if(isStory) {
// We are sending a story, so we ignore access checks and expect this
// to go out to both the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// We are not sending a story, so we need to do access checks.
if (authorize) {
// When authorized we send a message to the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// When forbidden, we return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} else {
// This is the case where the recipient has not enabled stories.
if (isStory) {
// We are sending a story, so we ignore access checks.
// this recipient has one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// We are not sending a story so check access.
if (authorize) {
// If allowed, send a message to the recipient's one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// If forbidden, return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} }
} }
@SafeVarargs // Arguments here are: recipient-UUID, is-authorized?, is-story?
private static <K, V> Map<K, V> submap(Map<K, V> map, K... keys) { private static Stream<Arguments> testMultiRecipientMessage() {
return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get)); return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
);
} }
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() { @Test
return void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
Map.of( UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
);
}
private record MultiRecipientMessageTestCase( final List<Recipient> recipients = List.of(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations, new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
boolean authenticated, new byte[48]),
boolean story, new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
int expectedStatus,
int expectedSentMessages) {
}
@CartesianTest Response response = resources
@CartesianTest.MethodFactory("testMultiRecipientMessageNoPni") .getJerseyTest()
void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception { .target("/v1/messages/multi_recipient")
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages()); .queryParam("online", true)
} .queryParam("ts", 1700000000000L)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
MultiRecipientMessageProvider.MEDIA_TYPE));
private static ArgumentSets testMultiRecipientMessageNoPni() { checkGoodMultiRecipientResponse(response, 1);
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap(
targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new AciServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true) // urgent
.argumentsForNextParameter(false, true); // explicitIdentifiers
}
@CartesianTest
@CartesianTest.MethodFactory("testMultiRecipientMessagePni")
void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages());
}
private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap(
targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new PniServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true); // urgent
} }
@ParameterizedTest @ParameterizedTest
@ -1174,7 +1128,7 @@ class MessageControllerTest {
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of( final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources Response response = resources
@ -1372,12 +1326,12 @@ class MessageControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2) void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
throws NotPushRegisteredException, InterruptedException { throws NotPushRegisteredException, InterruptedException {
final List<Recipient> recipients = List.of( final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]), new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48])); new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// initialize our binary payload and create an input stream // initialize our binary payload and create an input stream
byte[] buffer = new byte[2048]; byte[] buffer = new byte[2048];
@ -1410,8 +1364,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() { private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of( return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
} }
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
@ -1419,6 +1373,14 @@ class MessageControllerTest {
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
} }
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.uuids404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
@ -1451,4 +1413,64 @@ class MessageControllerTest {
return builder.build(); return builder.build();
} }
private static Recipient genRecipient(Random rng) {
UUID u1 = UUID.randomUUID(); // non-null
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
ByteBuffer bb = ByteBuffer.wrap(bytes);
writePayloadDeviceId(bb, expected);
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
long got = MultiRecipientMessageProvider.readVarint(stream);
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
}
@Test
void testVarintPayload() throws Exception {
Random rng = new Random();
byte[] bytes = new byte[12];
// some static test cases
for (byte i = 1; i <= 10; i++) {
roundTripVarint(i, bytes);
}
roundTripVarint(Byte.MAX_VALUE, bytes);
for (int i = 0; i < 1000; i++) {
// we need to ensure positive device IDs
byte start = (byte) rng.nextInt(128);
if (start == 0L) {
start = 1;
}
// run the test for this case
roundTripVarint(start, bytes);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
Random rng = new java.util.Random();
List<Recipient> expected = new LinkedList<>();
for(int i = 0; i < 100; i++) {
expected.add(genRecipient(rng));
}
byte[] buffer = new byte[100 + expected.size() * 100];
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
// the provider ignores the headers, java reflection, etc. so we don't use those here.
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
List<Recipient> got = Arrays.asList(res.recipients());
assertEquals(expected, got);
}
} }

View File

@ -10,8 +10,6 @@ import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted; import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
import static org.mockito.internal.exceptions.Reporter.tooFewActualInvocations;
import static org.mockito.internal.exceptions.Reporter.tooManyActualInvocations;
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked; import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
import static org.mockito.internal.invocation.InvocationMarker.markVerified; import static org.mockito.internal.invocation.InvocationMarker.markVerified;
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified; import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
@ -28,7 +26,6 @@ import org.mockito.Mockito;
import org.mockito.invocation.Invocation; import org.mockito.invocation.Invocation;
import org.mockito.invocation.MatchableInvocation; import org.mockito.invocation.MatchableInvocation;
import org.mockito.verification.VerificationMode; import org.mockito.verification.VerificationMode;
import org.mockito.internal.verification.Times;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
@ -174,17 +171,10 @@ public final class MockUtils {
* this method * this method
*/ */
public static VerificationMode exactly() { public static VerificationMode exactly() {
return exactly(1);
}
/**
* a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that
* there are exactly N invocations of this method, and all of them match the given specification
*/
public static VerificationMode exactly(int wantedCount) {
return data -> { return data -> {
MatchableInvocation target = data.getTarget(); MatchableInvocation target = data.getTarget();
final List<Invocation> allInvocations = data.getAllInvocations(); final List<Invocation> allInvocations = data.getAllInvocations();
List<Invocation> chunk = findInvocations(allInvocations, target);
List<Invocation> otherInvocations = allInvocations.stream() List<Invocation> otherInvocations = allInvocations.stream()
.filter(target::hasSameMethod) .filter(target::hasSameMethod)
.filter(Predicate.not(target::matches)) .filter(Predicate.not(target::matches))
@ -194,7 +184,10 @@ public final class MockUtils {
Invocation unverified = findFirstUnverified(otherInvocations); Invocation unverified = findFirstUnverified(otherInvocations);
throw noMoreInteractionsWanted(unverified, (List) allInvocations); throw noMoreInteractionsWanted(unverified, (List) allInvocations);
} }
Mockito.times(wantedCount).verify(data); if (chunk.isEmpty()) {
throw wantedButNotInvoked(target);
}
markVerified(chunk.get(0), target);
}; };
} }