diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index b61f5c298..3a43eaf47 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -22,6 +22,7 @@ import java.util.Base64; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -90,6 +91,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; @@ -226,9 +228,9 @@ public class MessageController { excludedDeviceIds); DestinationDeviceValidator.validateRegistrationIds(destination.get(), - messages.getMessages().stream().collect(Collectors.toMap( - IncomingMessage::getDestinationDeviceId, - IncomingMessage::getDestinationRegistrationId)), + messages.getMessages(), + IncomingMessage::getDestinationDeviceId, + IncomingMessage::getDestinationRegistrationId, destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), @@ -333,9 +335,9 @@ public class MessageController { DestinationDeviceValidator.validateRegistrationIds( destination.get(), - Arrays.stream(messages).collect(Collectors.toMap( - IncomingDeviceMessage::getDeviceId, - IncomingDeviceMessage::getRegistrationId)), + Arrays.stream(messages).toList(), + IncomingDeviceMessage::getDeviceId, + IncomingDeviceMessage::getRegistrationId, destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), @@ -395,29 +397,33 @@ public class MessageController { })); checkAccessKeys(accessKeys, uuidToAccountMap); - final Map> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients()) - .collect(Collectors.toMap( - recipient -> uuidToAccountMap.get(recipient.getUuid()), - recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()), - (a, b) -> { - final Map combined = new HashMap<>(); - combined.putAll(a); - combined.putAll(b); - - return combined; - } - )); + final Map>> accountToDeviceIdAndRegistrationIdMap = + Arrays + .stream(multiRecipientMessage.getRecipients()) + .collect(Collectors.toMap( + recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> new HashSet<>( + Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), + (a, b) -> { + a.addAll(b); + return a; + } + )); Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); uuidToAccountMap.values().forEach(account -> { - final Set deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet(); + final Set deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first) + .collect(Collectors.toSet()); try { DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); // Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number // identity - DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false); + DestinationDeviceValidator.validateRegistrationIds( + account, + accountToDeviceIdAndRegistrationIdMap.get(account).stream(), + false); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index b87168cad..5733822c7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.util.Pair; public class ChangeNumberManager { private static final Logger logger = LoggerFactory.getLogger(AccountController.class); @@ -59,10 +60,9 @@ public class ChangeNumberManager { DestinationDeviceValidator.validateRegistrationIds( account, - deviceMessages.stream() - .collect(Collectors.toMap( - IncomingMessage::getDestinationDeviceId, - IncomingMessage::getDestinationRegistrationId)), + deviceMessages, + IncomingMessage::getDestinationDeviceId, + IncomingMessage::getDestinationRegistrationId, false); } else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java index 3e44a8f7c..1c26c731e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java @@ -5,11 +5,14 @@ package org.whispersystems.textsecuregcm.util; import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.storage.Account; @@ -17,39 +20,50 @@ import org.whispersystems.textsecuregcm.storage.Device; public class DestinationDeviceValidator { + /** + * @see #validateRegistrationIds(Account, Stream, boolean) + */ + public static void validateRegistrationIds(final Account account, final Collection messages, + Function getDeviceId, Function getRegistrationId, boolean usePhoneNumberIdentity) + throws StaleDevicesException { + validateRegistrationIds(account, + messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))), + usePhoneNumberIdentity); + + } + /** * Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID * pairs in the given destination account. This method does not validate that all devices associated with the * destination account are present in the given device ID/registration ID pairs. * - * @param account the destination account against which to check the given device ID/registration ID pairs - * @param registrationIdsByDeviceId a map of device IDs to registration IDs - * @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs - * associated with the account's PNI (if available); compare against the ACI-associated - * registration ID otherwise - * + * @param account the destination account against which to check the given device + * ID/registration ID pairs + * @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs + * @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device + * registration IDs associated with the account's PNI (if available); compare + * against the ACI-associated registration ID otherwise * @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination * account does not have a corresponding device or if the registration IDs do not match */ public static void validateRegistrationIds(final Account account, - final Map registrationIdsByDeviceId, + final Stream> deviceIdAndRegistrationIdStream, final boolean usePhoneNumberIdentity) throws StaleDevicesException { - final List staleDevices = new ArrayList<>(); - - registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { - if (registrationId > 0) { - final boolean registrationIdMatches = - account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ? - device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : - device.getRegistrationId())) - .orElse(false); - - if (!registrationIdMatches) { - staleDevices.add(deviceId); - } - } - }); + final List staleDevices = deviceIdAndRegistrationIdStream + .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) + .filter(deviceIdAndRegistrationId -> { + final long deviceId = deviceIdAndRegistrationId.first(); + final int registrationId = deviceIdAndRegistrationId.second(); + boolean registrationIdMatches = account.getDevice(deviceId) + .map(device -> registrationId == (usePhoneNumberIdentity + ? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) + : device.getRegistrationId())) + .orElse(false); + return !registrationIdMatches; + }) + .map(Pair::first) + .collect(Collectors.toList()); if (!staleDevices.isEmpty()) { throw new StaleDevicesException(staleDevices); @@ -63,11 +77,10 @@ public class DestinationDeviceValidator { * "sync," message, though, the authenticated account is sending messages from one of their devices to all other * devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}. * - * @param account the destination account against which to check the given set of device IDs - * @param messageDeviceIds the set of device IDs to check against the destination account + * @param account the destination account against which to check the given set of device IDs + * @param messageDeviceIds the set of device IDs to check against the destination account * @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be * present in the given set of device IDs (i.e. the device that is sending a sync message) - * * @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with * the destination account or is missing entries associated with the destination * account diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java index a9eaf1c57..2ad4e7837 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java @@ -84,10 +84,17 @@ class DestinationDeviceValidatorTest { Set expectedStaleDeviceIds) throws Exception { if (expectedStaleDeviceIds != null) { Assertions.assertThat(assertThrows(StaleDevicesException.class, - () -> DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false)).getStaleDevices()) + () -> DestinationDeviceValidator.validateRegistrationIds( + account, + registrationIdsByDeviceId.entrySet(), + Map.Entry::getKey, + Map.Entry::getValue, + false)) + .getStaleDevices()) .hasSameElementsAs(expectedStaleDeviceIds); } else { - DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false); + DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(), + Map.Entry::getKey, Map.Entry::getValue, false); } } @@ -183,6 +190,18 @@ class DestinationDeviceValidatorTest { } } + @Test + void testDuplicateDeviceIds() { + final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.MASTER_ID, 17)); + try { + DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, 16), new Pair<>(Device.MASTER_ID, 17)), false); + Assertions.fail("duplicate devices should throw StaleDevicesException"); + } catch (StaleDevicesException e) { + Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.MASTER_ID)); + } + } + @Test void testValidatePniRegistrationIds() { final Device device = mock(Device.class); @@ -199,16 +218,35 @@ class DestinationDeviceValidatorTest { when(device.getRegistrationId()).thenReturn(aciRegistrationId); when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId)); - assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); - assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), false)); + assertDoesNotThrow( + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), false)); + assertDoesNotThrow( + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)), + true)); + assertThrows(StaleDevicesException.class, + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), + true)); + assertThrows(StaleDevicesException.class, + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)), + false)); when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); - assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); - assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), false)); + assertDoesNotThrow( + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), + false)); + assertDoesNotThrow( + () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), + true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, + Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), false)); } }