Handle duplicate device ids more gracefully
This commit is contained in:
parent
98760b631b
commit
36050f580e
|
@ -22,6 +22,7 @@ import java.util.Base64;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -90,6 +91,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
||||||
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||||
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
|
||||||
|
@ -226,9 +228,9 @@ public class MessageController {
|
||||||
excludedDeviceIds);
|
excludedDeviceIds);
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(destination.get(),
|
DestinationDeviceValidator.validateRegistrationIds(destination.get(),
|
||||||
messages.getMessages().stream().collect(Collectors.toMap(
|
messages.getMessages(),
|
||||||
IncomingMessage::getDestinationDeviceId,
|
IncomingMessage::getDestinationDeviceId,
|
||||||
IncomingMessage::getDestinationRegistrationId)),
|
IncomingMessage::getDestinationRegistrationId,
|
||||||
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
|
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
|
||||||
|
|
||||||
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||||
|
@ -333,9 +335,9 @@ public class MessageController {
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(
|
DestinationDeviceValidator.validateRegistrationIds(
|
||||||
destination.get(),
|
destination.get(),
|
||||||
Arrays.stream(messages).collect(Collectors.toMap(
|
Arrays.stream(messages).toList(),
|
||||||
IncomingDeviceMessage::getDeviceId,
|
IncomingDeviceMessage::getDeviceId,
|
||||||
IncomingDeviceMessage::getRegistrationId)),
|
IncomingDeviceMessage::getRegistrationId,
|
||||||
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
|
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
|
||||||
|
|
||||||
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||||
|
@ -395,29 +397,33 @@ public class MessageController {
|
||||||
}));
|
}));
|
||||||
checkAccessKeys(accessKeys, uuidToAccountMap);
|
checkAccessKeys(accessKeys, uuidToAccountMap);
|
||||||
|
|
||||||
final Map<Account, Map<Long, Integer>> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients())
|
final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
|
||||||
.collect(Collectors.toMap(
|
Arrays
|
||||||
recipient -> uuidToAccountMap.get(recipient.getUuid()),
|
.stream(multiRecipientMessage.getRecipients())
|
||||||
recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()),
|
.collect(Collectors.toMap(
|
||||||
(a, b) -> {
|
recipient -> uuidToAccountMap.get(recipient.getUuid()),
|
||||||
final Map<Long, Integer> combined = new HashMap<>();
|
recipient -> new HashSet<>(
|
||||||
combined.putAll(a);
|
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
|
||||||
combined.putAll(b);
|
(a, b) -> {
|
||||||
|
a.addAll(b);
|
||||||
return combined;
|
return a;
|
||||||
}
|
}
|
||||||
));
|
));
|
||||||
|
|
||||||
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
||||||
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
||||||
uuidToAccountMap.values().forEach(account -> {
|
uuidToAccountMap.values().forEach(account -> {
|
||||||
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet();
|
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
try {
|
try {
|
||||||
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
|
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
|
||||||
|
|
||||||
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
|
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
|
||||||
// identity
|
// identity
|
||||||
DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false);
|
DestinationDeviceValidator.validateRegistrationIds(
|
||||||
|
account,
|
||||||
|
accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
|
||||||
|
false);
|
||||||
} catch (MismatchedDevicesException e) {
|
} catch (MismatchedDevicesException e) {
|
||||||
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
|
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
|
||||||
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
|
||||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
|
|
||||||
public class ChangeNumberManager {
|
public class ChangeNumberManager {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
|
private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
|
||||||
|
@ -59,10 +60,9 @@ public class ChangeNumberManager {
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(
|
DestinationDeviceValidator.validateRegistrationIds(
|
||||||
account,
|
account,
|
||||||
deviceMessages.stream()
|
deviceMessages,
|
||||||
.collect(Collectors.toMap(
|
IncomingMessage::getDestinationDeviceId,
|
||||||
IncomingMessage::getDestinationDeviceId,
|
IncomingMessage::getDestinationRegistrationId,
|
||||||
IncomingMessage::getDestinationRegistrationId)),
|
|
||||||
false);
|
false);
|
||||||
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
|
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
|
||||||
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
|
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
|
||||||
|
|
|
@ -5,11 +5,14 @@
|
||||||
package org.whispersystems.textsecuregcm.util;
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.function.Function;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
|
@ -17,39 +20,50 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
|
||||||
public class DestinationDeviceValidator {
|
public class DestinationDeviceValidator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @see #validateRegistrationIds(Account, Stream, boolean)
|
||||||
|
*/
|
||||||
|
public static <T> void validateRegistrationIds(final Account account, final Collection<T> messages,
|
||||||
|
Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity)
|
||||||
|
throws StaleDevicesException {
|
||||||
|
validateRegistrationIds(account,
|
||||||
|
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
|
||||||
|
usePhoneNumberIdentity);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
|
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
|
||||||
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
|
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
|
||||||
* destination account are present in the given device ID/registration ID pairs.
|
* destination account are present in the given device ID/registration ID pairs.
|
||||||
*
|
*
|
||||||
* @param account the destination account against which to check the given device ID/registration ID pairs
|
* @param account the destination account against which to check the given device
|
||||||
* @param registrationIdsByDeviceId a map of device IDs to registration IDs
|
* ID/registration ID pairs
|
||||||
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs
|
* @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
|
||||||
* associated with the account's PNI (if available); compare against the ACI-associated
|
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
|
||||||
* registration ID otherwise
|
* registration IDs associated with the account's PNI (if available); compare
|
||||||
*
|
* against the ACI-associated registration ID otherwise
|
||||||
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
|
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
|
||||||
* account does not have a corresponding device or if the registration IDs do not match
|
* account does not have a corresponding device or if the registration IDs do not match
|
||||||
*/
|
*/
|
||||||
public static void validateRegistrationIds(final Account account,
|
public static void validateRegistrationIds(final Account account,
|
||||||
final Map<Long, Integer> registrationIdsByDeviceId,
|
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream,
|
||||||
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
|
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
|
||||||
|
|
||||||
final List<Long> staleDevices = new ArrayList<>();
|
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
|
||||||
|
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
|
||||||
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
|
.filter(deviceIdAndRegistrationId -> {
|
||||||
if (registrationId > 0) {
|
final long deviceId = deviceIdAndRegistrationId.first();
|
||||||
final boolean registrationIdMatches =
|
final int registrationId = deviceIdAndRegistrationId.second();
|
||||||
account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ?
|
boolean registrationIdMatches = account.getDevice(deviceId)
|
||||||
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
|
.map(device -> registrationId == (usePhoneNumberIdentity
|
||||||
device.getRegistrationId()))
|
? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
|
||||||
.orElse(false);
|
: device.getRegistrationId()))
|
||||||
|
.orElse(false);
|
||||||
if (!registrationIdMatches) {
|
return !registrationIdMatches;
|
||||||
staleDevices.add(deviceId);
|
})
|
||||||
}
|
.map(Pair::first)
|
||||||
}
|
.collect(Collectors.toList());
|
||||||
});
|
|
||||||
|
|
||||||
if (!staleDevices.isEmpty()) {
|
if (!staleDevices.isEmpty()) {
|
||||||
throw new StaleDevicesException(staleDevices);
|
throw new StaleDevicesException(staleDevices);
|
||||||
|
@ -63,11 +77,10 @@ public class DestinationDeviceValidator {
|
||||||
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
|
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
|
||||||
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
|
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
|
||||||
*
|
*
|
||||||
* @param account the destination account against which to check the given set of device IDs
|
* @param account the destination account against which to check the given set of device IDs
|
||||||
* @param messageDeviceIds the set of device IDs to check against the destination account
|
* @param messageDeviceIds the set of device IDs to check against the destination account
|
||||||
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
|
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
|
||||||
* present in the given set of device IDs (i.e. the device that is sending a sync message)
|
* present in the given set of device IDs (i.e. the device that is sending a sync message)
|
||||||
*
|
|
||||||
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
|
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
|
||||||
* the destination account or is missing entries associated with the destination
|
* the destination account or is missing entries associated with the destination
|
||||||
* account
|
* account
|
||||||
|
|
|
@ -84,10 +84,17 @@ class DestinationDeviceValidatorTest {
|
||||||
Set<Long> expectedStaleDeviceIds) throws Exception {
|
Set<Long> expectedStaleDeviceIds) throws Exception {
|
||||||
if (expectedStaleDeviceIds != null) {
|
if (expectedStaleDeviceIds != null) {
|
||||||
Assertions.assertThat(assertThrows(StaleDevicesException.class,
|
Assertions.assertThat(assertThrows(StaleDevicesException.class,
|
||||||
() -> DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false)).getStaleDevices())
|
() -> DestinationDeviceValidator.validateRegistrationIds(
|
||||||
|
account,
|
||||||
|
registrationIdsByDeviceId.entrySet(),
|
||||||
|
Map.Entry::getKey,
|
||||||
|
Map.Entry::getValue,
|
||||||
|
false))
|
||||||
|
.getStaleDevices())
|
||||||
.hasSameElementsAs(expectedStaleDeviceIds);
|
.hasSameElementsAs(expectedStaleDeviceIds);
|
||||||
} else {
|
} else {
|
||||||
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false);
|
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
|
||||||
|
Map.Entry::getKey, Map.Entry::getValue, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,6 +190,18 @@ class DestinationDeviceValidatorTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testDuplicateDeviceIds() {
|
||||||
|
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.MASTER_ID, 17));
|
||||||
|
try {
|
||||||
|
DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, 16), new Pair<>(Device.MASTER_ID, 17)), false);
|
||||||
|
Assertions.fail("duplicate devices should throw StaleDevicesException");
|
||||||
|
} catch (StaleDevicesException e) {
|
||||||
|
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.MASTER_ID));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testValidatePniRegistrationIds() {
|
void testValidatePniRegistrationIds() {
|
||||||
final Device device = mock(Device.class);
|
final Device device = mock(Device.class);
|
||||||
|
@ -199,16 +218,35 @@ class DestinationDeviceValidatorTest {
|
||||||
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
|
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
|
||||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
|
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
|
||||||
|
|
||||||
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false));
|
assertDoesNotThrow(
|
||||||
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), true));
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true));
|
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)), false));
|
||||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), false));
|
assertDoesNotThrow(
|
||||||
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)),
|
||||||
|
true));
|
||||||
|
assertThrows(StaleDevicesException.class,
|
||||||
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
|
||||||
|
true));
|
||||||
|
assertThrows(StaleDevicesException.class,
|
||||||
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, pniRegistrationId)),
|
||||||
|
false));
|
||||||
|
|
||||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
|
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
|
||||||
|
|
||||||
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false));
|
assertDoesNotThrow(
|
||||||
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true));
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), true));
|
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
|
||||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), false));
|
false));
|
||||||
|
assertDoesNotThrow(
|
||||||
|
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, aciRegistrationId)),
|
||||||
|
true));
|
||||||
|
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), true));
|
||||||
|
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||||
|
Stream.of(new Pair<>(Device.MASTER_ID, incorrectRegistrationId)), false));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue