Handle duplicate device ids more gracefully

This commit is contained in:
Ravi Khadiwala 2022-07-27 10:01:43 -05:00 committed by ravi-signal
parent 98760b631b
commit 36050f580e
4 changed files with 117 additions and 60 deletions

View File

@ -22,6 +22,7 @@ import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -90,6 +91,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@ -226,9 +228,9 @@ public class MessageController {
excludedDeviceIds); excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination.get(), DestinationDeviceValidator.validateRegistrationIds(destination.get(),
messages.getMessages().stream().collect(Collectors.toMap( messages.getMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationRegistrationId)), IncomingMessage::getDestinationRegistrationId,
destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
@ -333,9 +335,9 @@ public class MessageController {
DestinationDeviceValidator.validateRegistrationIds( DestinationDeviceValidator.validateRegistrationIds(
destination.get(), destination.get(),
Arrays.stream(messages).collect(Collectors.toMap( Arrays.stream(messages).toList(),
IncomingDeviceMessage::getDeviceId, IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId)), IncomingDeviceMessage::getRegistrationId,
destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
@ -395,29 +397,33 @@ public class MessageController {
})); }));
checkAccessKeys(accessKeys, uuidToAccountMap); checkAccessKeys(accessKeys, uuidToAccountMap);
final Map<Account, Map<Long, Integer>> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients()) final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
.collect(Collectors.toMap( Arrays
recipient -> uuidToAccountMap.get(recipient.getUuid()), .stream(multiRecipientMessage.getRecipients())
recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()), .collect(Collectors.toMap(
(a, b) -> { recipient -> uuidToAccountMap.get(recipient.getUuid()),
final Map<Long, Integer> combined = new HashMap<>(); recipient -> new HashSet<>(
combined.putAll(a); Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
combined.putAll(b); (a, b) -> {
a.addAll(b);
return combined; return a;
} }
)); ));
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
uuidToAccountMap.values().forEach(account -> { uuidToAccountMap.values().forEach(account -> {
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet(); final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first)
.collect(Collectors.toSet());
try { try {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number // Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity // identity
DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false); DestinationDeviceValidator.validateRegistrationIds(
account,
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
false);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));

View File

@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.util.Pair;
public class ChangeNumberManager { public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(AccountController.class); private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
@ -59,10 +60,9 @@ public class ChangeNumberManager {
DestinationDeviceValidator.validateRegistrationIds( DestinationDeviceValidator.validateRegistrationIds(
account, account,
deviceMessages.stream() deviceMessages,
.collect(Collectors.toMap( IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId,
IncomingMessage::getDestinationRegistrationId)),
false); false);
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { } else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null"); throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");

View File

@ -5,11 +5,14 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -17,39 +20,50 @@ import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator { public class DestinationDeviceValidator {
/**
* @see #validateRegistrationIds(Account, Stream, boolean)
*/
public static <T> void validateRegistrationIds(final Account account, final Collection<T> messages,
Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity)
throws StaleDevicesException {
validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
usePhoneNumberIdentity);
}
/** /**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID * Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the * pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs. * destination account are present in the given device ID/registration ID pairs.
* *
* @param account the destination account against which to check the given device ID/registration ID pairs * @param account the destination account against which to check the given device
* @param registrationIdsByDeviceId a map of device IDs to registration IDs * ID/registration ID pairs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs * @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
* associated with the account's PNI (if available); compare against the ACI-associated * @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
* registration ID otherwise * registration IDs associated with the account's PNI (if available); compare
* * against the ACI-associated registration ID otherwise
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination * @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match * account does not have a corresponding device or if the registration IDs do not match
*/ */
public static void validateRegistrationIds(final Account account, public static void validateRegistrationIds(final Account account,
final Map<Long, Integer> registrationIdsByDeviceId, final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream,
final boolean usePhoneNumberIdentity) throws StaleDevicesException { final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Long> staleDevices = new ArrayList<>(); final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { .filter(deviceIdAndRegistrationId -> {
if (registrationId > 0) { final long deviceId = deviceIdAndRegistrationId.first();
final boolean registrationIdMatches = final int registrationId = deviceIdAndRegistrationId.second();
account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ? boolean registrationIdMatches = account.getDevice(deviceId)
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : .map(device -> registrationId == (usePhoneNumberIdentity
device.getRegistrationId())) ? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
.orElse(false); : device.getRegistrationId()))
.orElse(false);
if (!registrationIdMatches) { return !registrationIdMatches;
staleDevices.add(deviceId); })
} .map(Pair::first)
} .collect(Collectors.toList());
});
if (!staleDevices.isEmpty()) { if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices); throw new StaleDevicesException(staleDevices);
@ -63,11 +77,10 @@ public class DestinationDeviceValidator {
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other * "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}. * devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
* *
* @param account the destination account against which to check the given set of device IDs * @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account * @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be * @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message) * present in the given set of device IDs (i.e. the device that is sending a sync message)
*
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with * @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination * the destination account or is missing entries associated with the destination
* account * account

View File

@ -84,10 +84,17 @@ class DestinationDeviceValidatorTest {
Set<Long> expectedStaleDeviceIds) throws Exception { Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) { if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false)).getStaleDevices()) () -> DestinationDeviceValidator.validateRegistrationIds(
account,
registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey,
Map.Entry::getValue,
false))
.getStaleDevices())
.hasSameElementsAs(expectedStaleDeviceIds); .hasSameElementsAs(expectedStaleDeviceIds);
} else { } else {
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false); DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey, Map.Entry::getValue, false);
} }
} }
@ -183,6 +190,18 @@ class DestinationDeviceValidatorTest {
} }
} }
@Test
void testDuplicateDeviceIds() {
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.MASTER_ID, 17));
try {
DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, 16), new Pair<>(Device.MASTER_ID, 17)), false);
Assertions.fail("duplicate devices should throw StaleDevicesException");
} catch (StaleDevicesException e) {
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.MASTER_ID));
}
}
@Test @Test
void testValidatePniRegistrationIds() { void testValidatePniRegistrationIds() {
final Device device = mock(Device.class); final Device device = mock(Device.class);
@ -199,16 +218,35 @@ class DestinationDeviceValidatorTest {
when(device.getRegistrationId()).thenReturn(aciRegistrationId); when(device.getRegistrationId()).thenReturn(aciRegistrationId);
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId)); when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); assertDoesNotThrow(
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), true)); () -> DestinationDeviceValidator.validateRegistrationIds(account,
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), false));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), false)); assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)),
false));
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); assertDoesNotThrow(
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); () -> DestinationDeviceValidator.validateRegistrationIds(account,
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), true)); Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), false)); false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), false));
} }
} }