diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 0c1ee9e75..37e2999c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.validation.Valid; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; @@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKey import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; +import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -94,6 +96,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; @@ -360,16 +363,23 @@ public class MessageController { checkAccessKeys(accessKeys, uuidToAccountMap); List accountMismatchedDevices = new ArrayList<>(); + List accountStaleDevices = new ArrayList<>(); for (Account account : uuidToAccountMap.values()) { Set deviceIds = Arrays.stream(multiRecipientMessage.getRecipients()) .filter(recipient -> recipient.getUuid().equals(account.getUuid())) .map(Recipient::getDeviceId) .collect(Collectors.toSet()); + Stream> deviceIdAndRegistrationIdStream = Arrays.stream(multiRecipientMessage.getRecipients()) + .filter(recipient -> recipient.getUuid().equals(account.getUuid())) + .map(recipient -> new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId())); try { validateCompleteDeviceList(account, deviceIds, false); + validateRegistrationIds(account, deviceIdAndRegistrationIdStream); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); + } catch (StaleDevicesException e) { + accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices()))); } } if (!accountMismatchedDevices.isEmpty()) { @@ -379,6 +389,13 @@ public class MessageController { .entity(accountMismatchedDevices) .build(); } + if (!accountStaleDevices.isEmpty()) { + return Response + .status(410) + .type(MediaType.APPLICATION_JSON) + .entity(accountStaleDevices) + .build(); + } List tags = List.of( UserAgentTagUtil.getPlatformTag(userAgent), @@ -639,20 +656,23 @@ public class MessageController { } private void validateRegistrationIds(Account account, List messages) - throws StaleDevicesException - { - List staleDevices = new LinkedList<>(); + throws StaleDevicesException { + final Stream> deviceIdAndRegistrationIdStream = messages + .stream() + .map(message -> new Pair<>(message.getDestinationDeviceId(), message.getDestinationRegistrationId())); + validateRegistrationIds(account, deviceIdAndRegistrationIdStream); + } - for (IncomingMessage message : messages) { - Optional device = account.getDevice(message.getDestinationDeviceId()); - - if (device.isPresent() && - message.getDestinationRegistrationId() > 0 && - message.getDestinationRegistrationId() != device.get().getRegistrationId()) - { - staleDevices.add(device.get().getId()); - } - } + private void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) + throws StaleDevicesException { + final List staleDevices = deviceIdAndRegistrationIdStream + .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) + .filter(deviceIdAndRegistrationId -> { + Optional device = account.getDevice(deviceIdAndRegistrationId.first()); + return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId(); + }) + .map(Pair::first) + .collect(Collectors.toList()); if (!staleDevices.isEmpty()) { throw new StaleDevicesException(staleDevices); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java new file mode 100644 index 000000000..35d818ea0 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java @@ -0,0 +1,22 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.UUID; + +public class AccountStaleDevices { + @JsonProperty + public final UUID uuid; + + @JsonProperty + public final StaleDevices devices; + + public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { + this.uuid = uuid; + this.devices = devices; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java index 73ac76104..564d007f4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java @@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.entities; import java.util.UUID; +import javax.validation.Valid; +import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; import javax.validation.constraints.Size; @@ -21,13 +23,18 @@ public class MultiRecipientMessage { @Min(1) private final long deviceId; + @Min(0) + @Max(65535) + private final int registrationId; + @Size(min = 48, max = 48) @NotNull private final byte[] perRecipientKeyMaterial; - public Recipient(UUID uuid, long deviceId, byte[] perRecipientKeyMaterial) { + public Recipient(UUID uuid, long deviceId, int registrationId, byte[] perRecipientKeyMaterial) { this.uuid = uuid; this.deviceId = deviceId; + this.registrationId = registrationId; this.perRecipientKeyMaterial = perRecipientKeyMaterial; } @@ -39,6 +46,10 @@ public class MultiRecipientMessage { return deviceId; } + public int getRegistrationId() { + return registrationId; + } + public byte[] getPerRecipientKeyMaterial() { return perRecipientKeyMaterial; } @@ -46,6 +57,7 @@ public class MultiRecipientMessage { @NotNull @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) + @Valid private final Recipient[] recipients; @NotNull diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index 1f6ec12ed..da9029899 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -54,11 +54,12 @@ public class MultiRecipientMessageProvider implements MessageBodyReader