From 3e0baf82a49fb530283848a194da5ca0d83ce8bc Mon Sep 17 00:00:00 2001
From: erik-signal <113138376+erik-signal@users.noreply.github.com>
Date: Thu, 13 Oct 2022 15:33:51 -0400
Subject: [PATCH] Filter unknown UUIDs for /multi_recipient&story=true.
---
.../controllers/MessageController.java | 56 +++++----
.../controllers/MessageControllerTest.java | 114 ++++++++++++++----
2 files changed, 122 insertions(+), 48 deletions(-)
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java
index e9416a41b..e9fc2fa72 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java
@@ -287,9 +287,6 @@ public class MessageController {
/**
* Build mapping of accounts to devices/registration IDs.
- *
- * Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive
- * stories.
*
* @param multiRecipientMessage
* @param uuidToAccountMap
@@ -300,17 +297,20 @@ public class MessageController {
Map uuidToAccountMap
) {
- Stream recipients = Arrays.stream(multiRecipientMessage.getRecipients());
-
- return recipients.collect(Collectors.toMap(
- recipient -> uuidToAccountMap.get(recipient.getUuid()),
- recipient -> new HashSet<>(
- Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
- (a, b) -> {
- a.addAll(b);
- return a;
- }
- ));
+ return Arrays.stream(multiRecipientMessage.getRecipients())
+ // for normal messages, all recipients UUIDs are in the map,
+ // but story messages might specify inactive UUIDs, which we
+ // have previously filtered
+ .filter(r -> uuidToAccountMap.containsKey(r.getUuid()))
+ .collect(Collectors.toMap(
+ recipient -> uuidToAccountMap.get(recipient.getUuid()),
+ recipient -> new HashSet<>(
+ Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
+ (a, b) -> {
+ a.addAll(b);
+ return a;
+ }
+ ));
}
@Timed
@@ -328,14 +328,26 @@ public class MessageController {
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
@QueryParam("story") boolean isStory,
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) {
- Map uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
- .map(Recipient::getUuid)
- .distinct()
- .collect(Collectors.toUnmodifiableMap(
- Function.identity(),
- uuid -> accountsManager
- .getByAccountIdentifier(uuid)
- .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND))));
+
+ // we skip "missing" accounts when story=true.
+ // otherwise, we return a 404 status code.
+ final Function> accountFinder = uuid -> {
+ Optional res = accountsManager.getByAccountIdentifier(uuid);
+ if (!isStory && res.isEmpty()) {
+ throw new WebApplicationException(Status.NOT_FOUND);
+ }
+ return res.stream();
+ };
+
+ // build a map from UUID to accounts
+ Map uuidToAccountMap =
+ Arrays.stream(multiRecipientMessage.getRecipients())
+ .map(Recipient::getUuid)
+ .distinct()
+ .flatMap(accountFinder)
+ .collect(Collectors.toUnmodifiableMap(
+ Account::getUuid,
+ Function.identity()));
// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java
index 786aae29c..1fd895e15 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java
@@ -39,6 +39,7 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Base64;
+import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
+import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
@@ -95,14 +97,14 @@ import org.whispersystems.websocket.Stories;
class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
- private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
- private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
+ private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111");
+ private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final int SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
- private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
- private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
+ private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
+ private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222");
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
@@ -113,7 +115,7 @@ class MessageControllerTest {
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
- private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
+ private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private Account internationalAccount;
@@ -717,10 +719,10 @@ class MessageControllerTest {
);
}
- private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception {
+ private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception {
bb.putLong(msb); // uuid (first 8 bytes)
bb.putLong(lsb); // uuid (last 8 bytes)
- int x = deviceId;
+ long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
bb.put((byte)(x & 0x7f));
@@ -730,30 +732,22 @@ class MessageControllerTest {
bb.put(new byte[48]); // key material (48 bytes)
}
- private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception {
+ private InputStream initializeMultiPayload(List recipients, byte[] buffer) throws Exception {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
- // determine how many recipient/device pairs we will be writing
- int count;
- if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; }
- else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; }
- else { throw new Exception("unknown UUID: " + recipientUUID); }
-
- // first write the header header
+ // first write the header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
- bb.put((byte)count); // count varint, # of active devices for this user
+ bb.put((byte)recipients.size()); // count varint
- long msb = recipientUUID.getMostSignificantBits();
- long lsb = recipientUUID.getLeastSignificantBits();
-
- // write the recipient data for each recipient/device pair
- if (recipientUUID == MULTI_DEVICE_UUID) {
- writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1);
- writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2);
- } else {
- writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1);
+ Iterator it = recipients.iterator();
+ while (it.hasNext()) {
+ Recipient r = it.next();
+ UUID uuid = r.getUuid();
+ long msb = uuid.getMostSignificantBits();
+ long lsb = uuid.getLeastSignificantBits();
+ writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId());
}
// now write the actual message body (empty for now)
@@ -767,9 +761,20 @@ class MessageControllerTest {
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception {
+ final List recipients;
+ if (recipientUUID == MULTI_DEVICE_UUID) {
+ recipients = List.of(
+ new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
+ new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
+ );
+ } else {
+ recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
+ }
+
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
- InputStream stream = initializeMultiPayload(recipientUUID, buffer);
+ //InputStream stream = initializeMultiPayload(recipientUUID, buffer);
+ InputStream stream = initializeMultiPayload(recipients, buffer);
// set up the entity to use in our PUT request
Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@@ -895,6 +900,63 @@ class MessageControllerTest {
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
}
+ @ParameterizedTest
+ @MethodSource
+ void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception {
+
+ final Recipient r1;
+ if (known) {
+ r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
+ } else {
+ r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]);
+ }
+
+ Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
+ Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
+
+ List recipients = List.of(r1, r2, r3);
+
+ byte[] buffer = new byte[2048];
+ InputStream stream = initializeMultiPayload(recipients, buffer);
+ // set up the entity to use in our PUT request
+ Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
+
+ // This looks weird, but there is a method to the madness.
+ // new bytes[16] is equivalent to UNIDENTIFIED_ACCESS_BYTES ^ UNIDENTIFIED_ACCESS_BYTES
+ // (i.e. we need to XOR all the access keys together)
+ String accessBytes = Base64.getEncoder().encodeToString(new byte[16]);
+
+ // start building the request
+ Invocation.Builder bldr = resources
+ .getJerseyTest()
+ .target("/v1/messages/multi_recipient")
+ .queryParam("online", true)
+ .queryParam("ts", 1663798405641L)
+ .queryParam("story", story)
+ .request()
+ .header("User-Agent", "Test User Agent")
+ .header(OptionalAccess.UNIDENTIFIED, accessBytes);
+
+ // make the PUT request
+ Response response = bldr.put(entity);
+
+ if (story || known) {
+ // it's a story so we unconditionally get 200 ok
+ assertEquals(200, response.getStatus());
+ } else {
+ // unknown recipient means 404 not found
+ assertEquals(404, response.getStatus());
+ }
+ }
+
+ private static Stream testSendMultiRecipientMessageToUnknownAccounts() {
+ return Stream.of(
+ Arguments.of(true, true),
+ Arguments.of(true, false),
+ Arguments.of(false, true),
+ Arguments.of(false, false));
+ }
+
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());