diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index e9416a41b..e9fc2fa72 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -287,9 +287,6 @@ public class MessageController { /** * Build mapping of accounts to devices/registration IDs. - *

- * Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive - * stories. * * @param multiRecipientMessage * @param uuidToAccountMap @@ -300,17 +297,20 @@ public class MessageController { Map uuidToAccountMap ) { - Stream recipients = Arrays.stream(multiRecipientMessage.getRecipients()); - - return recipients.collect(Collectors.toMap( - recipient -> uuidToAccountMap.get(recipient.getUuid()), - recipient -> new HashSet<>( - Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), - (a, b) -> { - a.addAll(b); - return a; - } - )); + return Arrays.stream(multiRecipientMessage.getRecipients()) + // for normal messages, all recipients UUIDs are in the map, + // but story messages might specify inactive UUIDs, which we + // have previously filtered + .filter(r -> uuidToAccountMap.containsKey(r.getUuid())) + .collect(Collectors.toMap( + recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> new HashSet<>( + Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), + (a, b) -> { + a.addAll(b); + return a; + } + )); } @Timed @@ -328,14 +328,26 @@ public class MessageController { @QueryParam("urgent") @DefaultValue("true") final boolean isUrgent, @QueryParam("story") boolean isStory, @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { - Map uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) - .map(Recipient::getUuid) - .distinct() - .collect(Collectors.toUnmodifiableMap( - Function.identity(), - uuid -> accountsManager - .getByAccountIdentifier(uuid) - .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)))); + + // we skip "missing" accounts when story=true. + // otherwise, we return a 404 status code. + final Function> accountFinder = uuid -> { + Optional res = accountsManager.getByAccountIdentifier(uuid); + if (!isStory && res.isEmpty()) { + throw new WebApplicationException(Status.NOT_FOUND); + } + return res.stream(); + }; + + // build a map from UUID to accounts + Map uuidToAccountMap = + Arrays.stream(multiRecipientMessage.getRecipients()) + .map(Recipient::getUuid) + .distinct() + .flatMap(accountFinder) + .collect(Collectors.toUnmodifiableMap( + Account::getUuid, + Function.identity())); // Stories will be checked by the client; we bypass access checks here for stories. if (!isStory) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 786aae29c..1fd895e15 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -39,6 +39,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Base64; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; @@ -95,14 +97,14 @@ import org.whispersystems.websocket.Stories; class MessageControllerTest { private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; - private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); - private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); + private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111"); + private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111"); private static final int SINGLE_DEVICE_ID1 = 1; private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; - private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); - private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); + private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222"); + private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222"); private static final int MULTI_DEVICE_ID1 = 1; private static final int MULTI_DEVICE_ID2 = 2; private static final int MULTI_DEVICE_ID3 = 3; @@ -113,7 +115,7 @@ class MessageControllerTest { private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; - private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); + private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); private Account internationalAccount; @@ -717,10 +719,10 @@ class MessageControllerTest { ); } - private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception { + private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception { bb.putLong(msb); // uuid (first 8 bytes) bb.putLong(lsb); // uuid (last 8 bytes) - int x = deviceId; + long x = deviceId; // write the device-id in the 7-bit varint format we use, least significant bytes first. do { bb.put((byte)(x & 0x7f)); @@ -730,30 +732,22 @@ class MessageControllerTest { bb.put(new byte[48]); // key material (48 bytes) } - private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception { + private InputStream initializeMultiPayload(List recipients, byte[] buffer) throws Exception { // initialize a binary payload according to our wire format ByteBuffer bb = ByteBuffer.wrap(buffer); bb.order(ByteOrder.BIG_ENDIAN); - // determine how many recipient/device pairs we will be writing - int count; - if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; } - else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; } - else { throw new Exception("unknown UUID: " + recipientUUID); } - - // first write the header header + // first write the header bb.put(MultiRecipientMessageProvider.VERSION); // version byte - bb.put((byte)count); // count varint, # of active devices for this user + bb.put((byte)recipients.size()); // count varint - long msb = recipientUUID.getMostSignificantBits(); - long lsb = recipientUUID.getLeastSignificantBits(); - - // write the recipient data for each recipient/device pair - if (recipientUUID == MULTI_DEVICE_UUID) { - writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1); - writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2); - } else { - writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1); + Iterator it = recipients.iterator(); + while (it.hasNext()) { + Recipient r = it.next(); + UUID uuid = r.getUuid(); + long msb = uuid.getMostSignificantBits(); + long lsb = uuid.getLeastSignificantBits(); + writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId()); } // now write the actual message body (empty for now) @@ -767,9 +761,20 @@ class MessageControllerTest { @MethodSource void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception { + final List recipients; + if (recipientUUID == MULTI_DEVICE_UUID) { + recipients = List.of( + new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) + ); + } else { + recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + } + // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipientUUID, buffer); + //InputStream stream = initializeMultiPayload(recipientUUID, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer); // set up the entity to use in our PUT request Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); @@ -895,6 +900,63 @@ class MessageControllerTest { assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200))); } + @ParameterizedTest + @MethodSource + void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception { + + final Recipient r1; + if (known) { + r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); + } else { + r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]); + } + + Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); + Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); + + List recipients = List.of(r1, r2, r3); + + byte[] buffer = new byte[2048]; + InputStream stream = initializeMultiPayload(recipients, buffer); + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // This looks weird, but there is a method to the madness. + // new bytes[16] is equivalent to UNIDENTIFIED_ACCESS_BYTES ^ UNIDENTIFIED_ACCESS_BYTES + // (i.e. we need to XOR all the access keys together) + String accessBytes = Base64.getEncoder().encodeToString(new byte[16]); + + // start building the request + Invocation.Builder bldr = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("ts", 1663798405641L) + .queryParam("story", story) + .request() + .header("User-Agent", "Test User Agent") + .header(OptionalAccess.UNIDENTIFIED, accessBytes); + + // make the PUT request + Response response = bldr.put(entity); + + if (story || known) { + // it's a story so we unconditionally get 200 ok + assertEquals(200, response.getStatus()); + } else { + // unknown recipient means 404 not found + assertEquals(404, response.getStatus()); + } + } + + private static Stream testSendMultiRecipientMessageToUnknownAccounts() { + return Stream.of( + Arguments.of(true, true), + Arguments.of(true, false), + Arguments.of(false, true), + Arguments.of(false, false)); + } + private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());