Filter unknown UUIDs for /multi_recipient&story=true.

This commit is contained in:
erik-signal 2022-10-13 15:33:51 -04:00 committed by GitHub
parent 7a2683a06b
commit 3e0baf82a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 122 additions and 48 deletions

View File

@ -287,9 +287,6 @@ public class MessageController {
/**
* Build mapping of accounts to devices/registration IDs.
* <p>
* Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive
* stories.
*
* @param multiRecipientMessage
* @param uuidToAccountMap
@ -300,17 +297,20 @@ public class MessageController {
Map<UUID, Account> uuidToAccountMap
) {
Stream<Recipient> recipients = Arrays.stream(multiRecipientMessage.getRecipients());
return recipients.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
return Arrays.stream(multiRecipientMessage.getRecipients())
// for normal messages, all recipients UUIDs are in the map,
// but story messages might specify inactive UUIDs, which we
// have previously filtered
.filter(r -> uuidToAccountMap.containsKey(r.getUuid()))
.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
}
@Timed
@ -328,14 +328,26 @@ public class MessageController {
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
@QueryParam("story") boolean isStory,
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) {
Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
.map(Recipient::getUuid)
.distinct()
.collect(Collectors.toUnmodifiableMap(
Function.identity(),
uuid -> accountsManager
.getByAccountIdentifier(uuid)
.orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND))));
// we skip "missing" accounts when story=true.
// otherwise, we return a 404 status code.
final Function<UUID, Stream<Account>> accountFinder = uuid -> {
Optional<Account> res = accountsManager.getByAccountIdentifier(uuid);
if (!isStory && res.isEmpty()) {
throw new WebApplicationException(Status.NOT_FOUND);
}
return res.stream();
};
// build a map from UUID to accounts
Map<UUID, Account> uuidToAccountMap =
Arrays.stream(multiRecipientMessage.getRecipients())
.map(Recipient::getUuid)
.distinct()
.flatMap(accountFinder)
.collect(Collectors.toUnmodifiableMap(
Account::getUuid,
Function.identity()));
// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {

View File

@ -39,6 +39,7 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
@ -95,14 +97,14 @@ import org.whispersystems.websocket.Stories;
class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111");
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final int SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222");
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
@ -113,7 +115,7 @@ class MessageControllerTest {
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private Account internationalAccount;
@ -717,10 +719,10 @@ class MessageControllerTest {
);
}
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception {
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception {
bb.putLong(msb); // uuid (first 8 bytes)
bb.putLong(lsb); // uuid (last 8 bytes)
int x = deviceId;
long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
bb.put((byte)(x & 0x7f));
@ -730,30 +732,22 @@ class MessageControllerTest {
bb.put(new byte[48]); // key material (48 bytes)
}
private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception {
private InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
// determine how many recipient/device pairs we will be writing
int count;
if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; }
else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; }
else { throw new Exception("unknown UUID: " + recipientUUID); }
// first write the header header
// first write the header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
bb.put((byte)count); // count varint, # of active devices for this user
bb.put((byte)recipients.size()); // count varint
long msb = recipientUUID.getMostSignificantBits();
long lsb = recipientUUID.getLeastSignificantBits();
// write the recipient data for each recipient/device pair
if (recipientUUID == MULTI_DEVICE_UUID) {
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1);
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2);
} else {
writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1);
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
Recipient r = it.next();
UUID uuid = r.getUuid();
long msb = uuid.getMostSignificantBits();
long lsb = uuid.getLeastSignificantBits();
writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId());
}
// now write the actual message body (empty for now)
@ -767,9 +761,20 @@ class MessageControllerTest {
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception {
final List<Recipient> recipients;
if (recipientUUID == MULTI_DEVICE_UUID) {
recipients = List.of(
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
);
} else {
recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
}
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipientUUID, buffer);
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@ -895,6 +900,63 @@ class MessageControllerTest {
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
}
@ParameterizedTest
@MethodSource
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception {
final Recipient r1;
if (known) {
r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]);
}
Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3);
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, buffer);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// This looks weird, but there is a method to the madness.
// new bytes[16] is equivalent to UNIDENTIFIED_ACCESS_BYTES ^ UNIDENTIFIED_ACCESS_BYTES
// (i.e. we need to XOR all the access keys together)
String accessBytes = Base64.getEncoder().encodeToString(new byte[16]);
// start building the request
Invocation.Builder bldr = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", story)
.request()
.header("User-Agent", "Test User Agent")
.header(OptionalAccess.UNIDENTIFIED, accessBytes);
// make the PUT request
Response response = bldr.put(entity);
if (story || known) {
// it's a story so we unconditionally get 200 ok
assertEquals(200, response.getStatus());
} else {
// unknown recipient means 404 not found
assertEquals(404, response.getStatus());
}
}
private static Stream<Arguments> testSendMultiRecipientMessageToUnknownAccounts() {
return Stream.of(
Arguments.of(true, true),
Arguments.of(true, false),
Arguments.of(false, true),
Arguments.of(false, false));
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());