diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 37e2999c2..ae21def9b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -25,6 +25,8 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -362,26 +364,34 @@ public class MessageController { })); checkAccessKeys(accessKeys, uuidToAccountMap); - List accountMismatchedDevices = new ArrayList<>(); - List accountStaleDevices = new ArrayList<>(); - for (Account account : uuidToAccountMap.values()) { - Set deviceIds = Arrays.stream(multiRecipientMessage.getRecipients()) - .filter(recipient -> recipient.getUuid().equals(account.getUuid())) - .map(Recipient::getDeviceId) - .collect(Collectors.toSet()); - Stream> deviceIdAndRegistrationIdStream = Arrays.stream(multiRecipientMessage.getRecipients()) - .filter(recipient -> recipient.getUuid().equals(account.getUuid())) - .map(recipient -> new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId())); + final Map>> accountToDeviceIdAndRegistrationIdMap = + Arrays + .stream(multiRecipientMessage.getRecipients()) + .collect(Collectors.toMap( + recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> new HashSet<>( + Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), + (a, b) -> { + a.addAll(b); + return a; + } + )); + + Collection accountMismatchedDevices = new ArrayList<>(); + Collection accountStaleDevices = new ArrayList<>(); + uuidToAccountMap.values().forEach(account -> { + final Set> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account); + final Set deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet()); try { validateCompleteDeviceList(account, deviceIds, false); - validateRegistrationIds(account, deviceIdAndRegistrationIdStream); + validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); } catch (StaleDevicesException e) { accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices()))); } - } + }); if (!accountMismatchedDevices.isEmpty()) { return Response .status(409)