Add handling of registration id in multi recipient send payload

This commit is contained in:
Ehren Kret 2021-05-17 13:01:37 -05:00
parent 10cd60738a
commit f76e6705c0
4 changed files with 86 additions and 15 deletions

View File

@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@ -94,6 +96,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@ -360,16 +363,23 @@ public class MessageController {
checkAccessKeys(accessKeys, uuidToAccountMap);
List<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
List<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
for (Account account : uuidToAccountMap.values()) {
Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(Recipient::getDeviceId)
.collect(Collectors.toSet());
Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(recipient -> new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()));
try {
validateCompleteDeviceList(account, deviceIds, false);
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices())));
}
}
if (!accountMismatchedDevices.isEmpty()) {
@ -379,6 +389,13 @@ public class MessageController {
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
List<Tag> tags = List.of(
UserAgentTagUtil.getPlatformTag(userAgent),
@ -639,21 +656,24 @@ public class MessageController {
}
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException
{
List<Long> staleDevices = new LinkedList<>();
for (IncomingMessage message : messages) {
Optional<Device> device = account.getDevice(message.getDestinationDeviceId());
if (device.isPresent() &&
message.getDestinationRegistrationId() > 0 &&
message.getDestinationRegistrationId() != device.get().getRegistrationId())
{
staleDevices.add(device.get().getId());
}
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId()));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
private void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}

View File

@ -0,0 +1,22 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.UUID;
public class AccountStaleDevices {
@JsonProperty
public final UUID uuid;
@JsonProperty
public final StaleDevices devices;
public AccountStaleDevices(final UUID uuid, final StaleDevices devices) {
this.uuid = uuid;
this.devices = devices;
}
}

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.entities;
import java.util.UUID;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
@ -21,13 +23,18 @@ public class MultiRecipientMessage {
@Min(1)
private final long deviceId;
@Min(0)
@Max(65535)
private final int registrationId;
@Size(min = 48, max = 48)
@NotNull
private final byte[] perRecipientKeyMaterial;
public Recipient(UUID uuid, long deviceId, byte[] perRecipientKeyMaterial) {
public Recipient(UUID uuid, long deviceId, int registrationId, byte[] perRecipientKeyMaterial) {
this.uuid = uuid;
this.deviceId = deviceId;
this.registrationId = registrationId;
this.perRecipientKeyMaterial = perRecipientKeyMaterial;
}
@ -39,6 +46,10 @@ public class MultiRecipientMessage {
return deviceId;
}
public int getRegistrationId() {
return registrationId;
}
public byte[] getPerRecipientKeyMaterial() {
return perRecipientKeyMaterial;
}
@ -46,6 +57,7 @@ public class MultiRecipientMessage {
@NotNull
@Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT)
@Valid
private final Recipient[] recipients;
@NotNull

View File

@ -54,11 +54,12 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
for (int i = 0; i < Math.toIntExact(count); i++) {
UUID uuid = readUuid(entityStream);
long deviceId = readVarint(entityStream);
int registrationId = readU16(entityStream);
byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
if (perRecipientKeyMaterial.length != 48) {
throw new IOException("Failed to read expected number of key material bytes for a recipient");
}
recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, perRecipientKeyMaterial);
recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, registrationId, perRecipientKeyMaterial);
}
// caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than
@ -126,4 +127,20 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
}
return result;
}
/**
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
* {@code [0, 65535]}.
*/
private int readU16(InputStream stream) throws IOException {
int b1 = stream.read();
if (b1 == -1) {
throw new IOException("Missing byte 1 of U16");
}
int b2 = stream.read();
if (b2 == -1) {
throw new IOException("Missing byte 2 of U16");
}
return (b1 << 8) | b2;
}
}