Add handling of registration id in multi recipient send payload

This commit is contained in:
Ehren Kret 2021-05-17 13:01:37 -05:00
parent 10cd60738a
commit f76e6705c0
4 changed files with 86 additions and 15 deletions

View File

@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@ -94,6 +96,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@ -360,16 +363,23 @@ public class MessageController {
checkAccessKeys(accessKeys, uuidToAccountMap); checkAccessKeys(accessKeys, uuidToAccountMap);
List<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); List<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
List<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
for (Account account : uuidToAccountMap.values()) { for (Account account : uuidToAccountMap.values()) {
Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients()) Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid())) .filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(Recipient::getDeviceId) .map(Recipient::getDeviceId)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = Arrays.stream(multiRecipientMessage.getRecipients())
.filter(recipient -> recipient.getUuid().equals(account.getUuid()))
.map(recipient -> new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()));
try { try {
validateCompleteDeviceList(account, deviceIds, false); validateCompleteDeviceList(account, deviceIds, false);
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices())));
} }
} }
if (!accountMismatchedDevices.isEmpty()) { if (!accountMismatchedDevices.isEmpty()) {
@ -379,6 +389,13 @@ public class MessageController {
.entity(accountMismatchedDevices) .entity(accountMismatchedDevices)
.build(); .build();
} }
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
List<Tag> tags = List.of( List<Tag> tags = List.of(
UserAgentTagUtil.getPlatformTag(userAgent), UserAgentTagUtil.getPlatformTag(userAgent),
@ -639,21 +656,24 @@ public class MessageController {
} }
private void validateRegistrationIds(Account account, List<IncomingMessage> messages) private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException throws StaleDevicesException {
{ final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
List<Long> staleDevices = new LinkedList<>(); .stream()
.map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId()));
for (IncomingMessage message : messages) { validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
Optional<Device> device = account.getDevice(message.getDestinationDeviceId());
if (device.isPresent() &&
message.getDestinationRegistrationId() > 0 &&
message.getDestinationRegistrationId() != device.get().getRegistrationId())
{
staleDevices.add(device.get().getId());
}
} }
private void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) { if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices); throw new StaleDevicesException(staleDevices);
} }

View File

@ -0,0 +1,22 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.UUID;
public class AccountStaleDevices {
@JsonProperty
public final UUID uuid;
@JsonProperty
public final StaleDevices devices;
public AccountStaleDevices(final UUID uuid, final StaleDevices devices) {
this.uuid = uuid;
this.devices = devices;
}
}

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import java.util.UUID; import java.util.UUID;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min; import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size; import javax.validation.constraints.Size;
@ -21,13 +23,18 @@ public class MultiRecipientMessage {
@Min(1) @Min(1)
private final long deviceId; private final long deviceId;
@Min(0)
@Max(65535)
private final int registrationId;
@Size(min = 48, max = 48) @Size(min = 48, max = 48)
@NotNull @NotNull
private final byte[] perRecipientKeyMaterial; private final byte[] perRecipientKeyMaterial;
public Recipient(UUID uuid, long deviceId, byte[] perRecipientKeyMaterial) { public Recipient(UUID uuid, long deviceId, int registrationId, byte[] perRecipientKeyMaterial) {
this.uuid = uuid; this.uuid = uuid;
this.deviceId = deviceId; this.deviceId = deviceId;
this.registrationId = registrationId;
this.perRecipientKeyMaterial = perRecipientKeyMaterial; this.perRecipientKeyMaterial = perRecipientKeyMaterial;
} }
@ -39,6 +46,10 @@ public class MultiRecipientMessage {
return deviceId; return deviceId;
} }
public int getRegistrationId() {
return registrationId;
}
public byte[] getPerRecipientKeyMaterial() { public byte[] getPerRecipientKeyMaterial() {
return perRecipientKeyMaterial; return perRecipientKeyMaterial;
} }
@ -46,6 +57,7 @@ public class MultiRecipientMessage {
@NotNull @NotNull
@Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT)
@Valid
private final Recipient[] recipients; private final Recipient[] recipients;
@NotNull @NotNull

View File

@ -54,11 +54,12 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
for (int i = 0; i < Math.toIntExact(count); i++) { for (int i = 0; i < Math.toIntExact(count); i++) {
UUID uuid = readUuid(entityStream); UUID uuid = readUuid(entityStream);
long deviceId = readVarint(entityStream); long deviceId = readVarint(entityStream);
int registrationId = readU16(entityStream);
byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
if (perRecipientKeyMaterial.length != 48) { if (perRecipientKeyMaterial.length != 48) {
throw new IOException("Failed to read expected number of key material bytes for a recipient"); throw new IOException("Failed to read expected number of key material bytes for a recipient");
} }
recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, perRecipientKeyMaterial); recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, registrationId, perRecipientKeyMaterial);
} }
// caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than
@ -126,4 +127,20 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
} }
return result; return result;
} }
/**
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
* {@code [0, 65535]}.
*/
private int readU16(InputStream stream) throws IOException {
int b1 = stream.read();
if (b1 == -1) {
throw new IOException("Missing byte 1 of U16");
}
int b2 = stream.read();
if (b2 == -1) {
throw new IOException("Missing byte 2 of U16");
}
return (b1 << 8) | b2;
}
} }