Filter unknown UUIDs for /multi_recipient&story=true.
This commit is contained in:
parent
7a2683a06b
commit
3e0baf82a4
|
@ -287,9 +287,6 @@ public class MessageController {
|
|||
|
||||
/**
|
||||
* Build mapping of accounts to devices/registration IDs.
|
||||
* <p>
|
||||
* Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive
|
||||
* stories.
|
||||
*
|
||||
* @param multiRecipientMessage
|
||||
* @param uuidToAccountMap
|
||||
|
@ -300,17 +297,20 @@ public class MessageController {
|
|||
Map<UUID, Account> uuidToAccountMap
|
||||
) {
|
||||
|
||||
Stream<Recipient> recipients = Arrays.stream(multiRecipientMessage.getRecipients());
|
||||
|
||||
return recipients.collect(Collectors.toMap(
|
||||
recipient -> uuidToAccountMap.get(recipient.getUuid()),
|
||||
recipient -> new HashSet<>(
|
||||
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
|
||||
(a, b) -> {
|
||||
a.addAll(b);
|
||||
return a;
|
||||
}
|
||||
));
|
||||
return Arrays.stream(multiRecipientMessage.getRecipients())
|
||||
// for normal messages, all recipients UUIDs are in the map,
|
||||
// but story messages might specify inactive UUIDs, which we
|
||||
// have previously filtered
|
||||
.filter(r -> uuidToAccountMap.containsKey(r.getUuid()))
|
||||
.collect(Collectors.toMap(
|
||||
recipient -> uuidToAccountMap.get(recipient.getUuid()),
|
||||
recipient -> new HashSet<>(
|
||||
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
|
||||
(a, b) -> {
|
||||
a.addAll(b);
|
||||
return a;
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
@Timed
|
||||
|
@ -328,14 +328,26 @@ public class MessageController {
|
|||
@QueryParam("urgent") @DefaultValue("true") final boolean isUrgent,
|
||||
@QueryParam("story") boolean isStory,
|
||||
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) {
|
||||
Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
|
||||
.map(Recipient::getUuid)
|
||||
.distinct()
|
||||
.collect(Collectors.toUnmodifiableMap(
|
||||
Function.identity(),
|
||||
uuid -> accountsManager
|
||||
.getByAccountIdentifier(uuid)
|
||||
.orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND))));
|
||||
|
||||
// we skip "missing" accounts when story=true.
|
||||
// otherwise, we return a 404 status code.
|
||||
final Function<UUID, Stream<Account>> accountFinder = uuid -> {
|
||||
Optional<Account> res = accountsManager.getByAccountIdentifier(uuid);
|
||||
if (!isStory && res.isEmpty()) {
|
||||
throw new WebApplicationException(Status.NOT_FOUND);
|
||||
}
|
||||
return res.stream();
|
||||
};
|
||||
|
||||
// build a map from UUID to accounts
|
||||
Map<UUID, Account> uuidToAccountMap =
|
||||
Arrays.stream(multiRecipientMessage.getRecipients())
|
||||
.map(Recipient::getUuid)
|
||||
.distinct()
|
||||
.flatMap(accountFinder)
|
||||
.collect(Collectors.toUnmodifiableMap(
|
||||
Account::getUuid,
|
||||
Function.identity()));
|
||||
|
||||
// Stories will be checked by the client; we bypass access checks here for stories.
|
||||
if (!isStory) {
|
||||
|
|
|
@ -39,6 +39,7 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
|||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
|
@ -95,14 +97,14 @@ import org.whispersystems.websocket.Stories;
|
|||
class MessageControllerTest {
|
||||
|
||||
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
|
||||
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
|
||||
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
|
||||
private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111");
|
||||
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
|
||||
private static final int SINGLE_DEVICE_ID1 = 1;
|
||||
private static final int SINGLE_DEVICE_REG_ID1 = 111;
|
||||
|
||||
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
|
||||
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
|
||||
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
|
||||
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
|
||||
private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222");
|
||||
private static final int MULTI_DEVICE_ID1 = 1;
|
||||
private static final int MULTI_DEVICE_ID2 = 2;
|
||||
private static final int MULTI_DEVICE_ID3 = 3;
|
||||
|
@ -113,7 +115,7 @@ class MessageControllerTest {
|
|||
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
|
||||
|
||||
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
|
||||
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
|
||||
private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
|
||||
|
||||
private Account internationalAccount;
|
||||
|
||||
|
@ -717,10 +719,10 @@ class MessageControllerTest {
|
|||
);
|
||||
}
|
||||
|
||||
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception {
|
||||
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception {
|
||||
bb.putLong(msb); // uuid (first 8 bytes)
|
||||
bb.putLong(lsb); // uuid (last 8 bytes)
|
||||
int x = deviceId;
|
||||
long x = deviceId;
|
||||
// write the device-id in the 7-bit varint format we use, least significant bytes first.
|
||||
do {
|
||||
bb.put((byte)(x & 0x7f));
|
||||
|
@ -730,30 +732,22 @@ class MessageControllerTest {
|
|||
bb.put(new byte[48]); // key material (48 bytes)
|
||||
}
|
||||
|
||||
private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception {
|
||||
private InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
|
||||
// initialize a binary payload according to our wire format
|
||||
ByteBuffer bb = ByteBuffer.wrap(buffer);
|
||||
bb.order(ByteOrder.BIG_ENDIAN);
|
||||
|
||||
// determine how many recipient/device pairs we will be writing
|
||||
int count;
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; }
|
||||
else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; }
|
||||
else { throw new Exception("unknown UUID: " + recipientUUID); }
|
||||
|
||||
// first write the header header
|
||||
// first write the header
|
||||
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
|
||||
bb.put((byte)count); // count varint, # of active devices for this user
|
||||
bb.put((byte)recipients.size()); // count varint
|
||||
|
||||
long msb = recipientUUID.getMostSignificantBits();
|
||||
long lsb = recipientUUID.getLeastSignificantBits();
|
||||
|
||||
// write the recipient data for each recipient/device pair
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1);
|
||||
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2);
|
||||
} else {
|
||||
writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1);
|
||||
Iterator<Recipient> it = recipients.iterator();
|
||||
while (it.hasNext()) {
|
||||
Recipient r = it.next();
|
||||
UUID uuid = r.getUuid();
|
||||
long msb = uuid.getMostSignificantBits();
|
||||
long lsb = uuid.getLeastSignificantBits();
|
||||
writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId());
|
||||
}
|
||||
|
||||
// now write the actual message body (empty for now)
|
||||
|
@ -767,9 +761,20 @@ class MessageControllerTest {
|
|||
@MethodSource
|
||||
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception {
|
||||
|
||||
final List<Recipient> recipients;
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
recipients = List.of(
|
||||
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
|
||||
);
|
||||
} else {
|
||||
recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
|
||||
}
|
||||
|
||||
// initialize our binary payload and create an input stream
|
||||
byte[] buffer = new byte[2048];
|
||||
InputStream stream = initializeMultiPayload(recipientUUID, buffer);
|
||||
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
|
||||
InputStream stream = initializeMultiPayload(recipients, buffer);
|
||||
|
||||
// set up the entity to use in our PUT request
|
||||
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
|
||||
|
@ -895,6 +900,63 @@ class MessageControllerTest {
|
|||
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception {
|
||||
|
||||
final Recipient r1;
|
||||
if (known) {
|
||||
r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
|
||||
} else {
|
||||
r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]);
|
||||
}
|
||||
|
||||
Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
|
||||
Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
|
||||
|
||||
List<Recipient> recipients = List.of(r1, r2, r3);
|
||||
|
||||
byte[] buffer = new byte[2048];
|
||||
InputStream stream = initializeMultiPayload(recipients, buffer);
|
||||
// set up the entity to use in our PUT request
|
||||
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
|
||||
|
||||
// This looks weird, but there is a method to the madness.
|
||||
// new bytes[16] is equivalent to UNIDENTIFIED_ACCESS_BYTES ^ UNIDENTIFIED_ACCESS_BYTES
|
||||
// (i.e. we need to XOR all the access keys together)
|
||||
String accessBytes = Base64.getEncoder().encodeToString(new byte[16]);
|
||||
|
||||
// start building the request
|
||||
Invocation.Builder bldr = resources
|
||||
.getJerseyTest()
|
||||
.target("/v1/messages/multi_recipient")
|
||||
.queryParam("online", true)
|
||||
.queryParam("ts", 1663798405641L)
|
||||
.queryParam("story", story)
|
||||
.request()
|
||||
.header("User-Agent", "Test User Agent")
|
||||
.header(OptionalAccess.UNIDENTIFIED, accessBytes);
|
||||
|
||||
// make the PUT request
|
||||
Response response = bldr.put(entity);
|
||||
|
||||
if (story || known) {
|
||||
// it's a story so we unconditionally get 200 ok
|
||||
assertEquals(200, response.getStatus());
|
||||
} else {
|
||||
// unknown recipient means 404 not found
|
||||
assertEquals(404, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
private static Stream<Arguments> testSendMultiRecipientMessageToUnknownAccounts() {
|
||||
return Stream.of(
|
||||
Arguments.of(true, true),
|
||||
Arguments.of(true, false),
|
||||
Arguments.of(false, true),
|
||||
Arguments.of(false, false));
|
||||
}
|
||||
|
||||
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
|
||||
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
|
||||
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
|
||||
|
|
Loading…
Reference in New Issue