From dce391a248f0fd7e71b1a07c734776dfa1e4c8cb Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 26 Jul 2022 15:19:27 -0400 Subject: [PATCH] Add support for setting PNI-associated registration IDs and identity keys when changing numbers --- .../controllers/AccountController.java | 84 +++---- .../controllers/DeviceController.java | 1 + .../controllers/KeysController.java | 6 +- .../controllers/MessageController.java | 87 ++++--- .../entities/AccountAttributes.java | 10 + .../entities/ChangePhoneNumberRequest.java | 70 +----- .../entities/OutgoingMessageEntity.java | 31 ++- .../storage/AccountsManager.java | 46 +++- .../storage/ChangeNumberManager.java | 72 ++++-- .../textsecuregcm/storage/Device.java | 13 + .../textsecuregcm/storage/MessagesCache.java | 1 + .../storage/MessagesDynamoDb.java | 13 +- .../util/DestinationDeviceValidator.java | 95 ++++++++ .../textsecuregcm/util/MessageValidation.java | 84 ------- .../websocket/WebSocketConnection.java | 5 + service/src/main/proto/TextSecure.proto | 2 + .../controllers/MessageControllerTest.java | 106 ++++++-- ...ntsManagerChangeNumberIntegrationTest.java | 65 ++++- .../storage/AccountsManagerTest.java | 13 +- .../storage/ChangeNumberManagerTest.java | 149 ++++++++++-- .../controllers/AccountControllerTest.java | 147 +++--------- .../tests/controllers/KeysControllerTest.java | 31 ++- .../tests/util/MessageValidationTest.java | 227 ------------------ .../util/DestinationDeviceValidatorTest.java | 214 +++++++++++++++++ .../websocket/WebSocketConnectionTest.java | 12 +- .../current_message_multi_device_pni.json | 16 ++ 26 files changed, 927 insertions(+), 673 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java create mode 100644 service/src/test/resources/fixtures/current_message_multi_device_pni.json diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 3718f80cf..5df8d95e4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -68,13 +68,11 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest; import org.whispersystems.textsecuregcm.entities.DeviceName; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; -import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.ApnMessage; @@ -95,7 +93,6 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; import org.whispersystems.textsecuregcm.util.Hex; import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException; -import org.whispersystems.textsecuregcm.util.MessageValidation; import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException; import org.whispersystems.textsecuregcm.util.Username; import org.whispersystems.textsecuregcm.util.Util; @@ -416,41 +413,9 @@ public class AccountController { throw new ForbiddenException(); } - if (request.getDeviceSignedPrekeys() != null && !request.getDeviceSignedPrekeys().isEmpty()) { - if (request.getDeviceMessages() == null || request.getDeviceMessages().size() != request.getDeviceSignedPrekeys().size() - 1) { - // device_messages should exist and be one shorter than device_signed_prekeys, since it doesn't have the primary's key. - throw new WebApplicationException(Response.status(400).build()); - } - try { - // Checks that all except master ID are in device messages - MessageValidation.validateCompleteDeviceList( - authenticatedAccount.getAccount(), request.getDeviceMessages(), - IncomingMessage::getDestinationDeviceId, true, Optional.of(Device.MASTER_ID)); - MessageValidation.validateRegistrationIds( - authenticatedAccount.getAccount(), request.getDeviceMessages(), - IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId); - // Checks that all including master ID are in signed prekeys - MessageValidation.validateCompleteDeviceList( - authenticatedAccount.getAccount(), request.getDeviceSignedPrekeys().entrySet(), - e -> e.getKey(), false, Optional.empty()); - } catch (MismatchedDevicesException e) { - throw new WebApplicationException(Response.status(409) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity(new MismatchedDevices(e.getMissingDevices(), - e.getExtraDevices())) - .build()); - } catch (StaleDevicesException e) { - throw new WebApplicationException(Response.status(410) - .type(MediaType.APPLICATION_JSON) - .entity(new StaleDevices(e.getStaleDevices())) - .build()); - } - } else if (request.getDeviceMessages() != null && !request.getDeviceMessages().isEmpty()) { - // device_messages shouldn't exist without device_signed_prekeys. - throw new WebApplicationException(Response.status(400).build()); - } + final String number = request.number(); - final String number = request.getNumber(); + // Only "bill" for rate limiting if we think there's a change to be made... if (!authenticatedAccount.getAccount().getNumber().equals(number)) { Util.requireNormalizedNumber(number); @@ -459,7 +424,7 @@ public class AccountController { final Optional storedVerificationCode = pendingAccounts.getCodeForNumber(number); - if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.getCode())) { + if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.code())) { throw new ForbiddenException(); } @@ -469,24 +434,42 @@ public class AccountController { final Optional existingAccount = accounts.getByE164(number); if (existingAccount.isPresent()) { - verifyRegistrationLock(existingAccount.get(), request.getRegistrationLock()); + verifyRegistrationLock(existingAccount.get(), request.registrationLock()); } rateLimiters.getVerifyLimiter().clear(number); } - final Account updatedAccount = changeNumberManager.changeNumber( - authenticatedAccount.getAccount(), - request.getNumber(), - Optional.ofNullable(request.getDeviceSignedPrekeys()).orElse(Collections.emptyMap()), - Optional.ofNullable(request.getDeviceMessages()).orElse(Collections.emptyList())); + // ...but always attempt to make the change in case a client retries and needs to re-send messages + try { + final Account updatedAccount = changeNumberManager.changeNumber( + authenticatedAccount.getAccount(), + request.number(), + request.pniIdentityKey(), + Optional.ofNullable(request.devicePniSignedPrekeys()).orElse(Collections.emptyMap()), + Optional.ofNullable(request.deviceMessages()).orElse(Collections.emptyList()), + Optional.ofNullable(request.pniRegistrationIds()).orElse(Collections.emptyMap())); - return new AccountIdentityResponse( - updatedAccount.getUuid(), - updatedAccount.getNumber(), - updatedAccount.getPhoneNumberIdentifier(), - updatedAccount.getUsername().orElse(null), - updatedAccount.isStorageSupported()); + return new AccountIdentityResponse( + updatedAccount.getUuid(), + updatedAccount.getNumber(), + updatedAccount.getPhoneNumberIdentifier(), + updatedAccount.getUsername().orElse(null), + updatedAccount.isStorageSupported()); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), + e.getExtraDevices())) + .build()); + } catch (StaleDevicesException e) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevices(e.getStaleDevices())) + .build()); + } catch (IllegalArgumentException e) { + throw new BadRequestException(e); + } } @Timed @@ -625,6 +608,7 @@ public class AccountController { d.setLastSeen(Util.todayInMillis()); d.setCapabilities(attributes.getCapabilities()); d.setRegistrationId(attributes.getRegistrationId()); + attributes.getPhoneNumberIdentityRegistrationId().ifPresent(d::setPhoneNumberIdentityRegistrationId); d.setUserAgent(userAgent); }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index ef078dee0..d8fd1691e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -198,6 +198,7 @@ public class DeviceController { device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setRegistrationId(accountAttributes.getRegistrationId()); + accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId); device.setLastSeen(Util.todayInMillis()); device.setCreated(System.currentTimeMillis()); device.setCapabilities(accountAttributes.getCapabilities()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index b07f73f06..ab899a9da 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -197,7 +197,11 @@ public class KeysController { PreKey preKey = preKeysByDeviceId.get(device.getId()); if (signedPreKey != null || preKey != null) { - responseItems.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey)); + final int registrationId = usePhoneNumberIdentity ? + device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : + device.getRegistrationId(); + + responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey)); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 119c82d00..b61f5c298 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -21,7 +21,7 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -89,8 +89,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; -import org.whispersystems.textsecuregcm.util.MessageValidation; -import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; @@ -214,11 +213,23 @@ public class MessageController { checkRateLimit(source.get(), destination.get(), userAgent); } - MessageValidation.validateCompleteDeviceList(destination.get(), messages.getMessages(), - IncomingMessage::getDestinationDeviceId, isSyncMessage, - source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); - MessageValidation.validateRegistrationIds(destination.get(), messages.getMessages(), - IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId); + final Set excludedDeviceIds; + + if (isSyncMessage) { + excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId()); + } else { + excludedDeviceIds = Collections.emptySet(); + } + + DestinationDeviceValidator.validateCompleteDeviceList(destination.get(), + messages.getMessages().stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()), + excludedDeviceIds); + + DestinationDeviceValidator.validateRegistrationIds(destination.get(), + messages.getMessages().stream().collect(Collectors.toMap( + IncomingMessage::getDestinationDeviceId, + IncomingMessage::getDestinationRegistrationId)), + destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())), @@ -307,13 +318,25 @@ public class MessageController { checkRateLimit(source.get(), destination.get(), userAgent); } - final List messagesAsList = Arrays.asList(messages); - MessageValidation.validateCompleteDeviceList(destination.get(), messagesAsList, - IncomingDeviceMessage::getDeviceId, isSyncMessage, - source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); - MessageValidation.validateRegistrationIds(destination.get(), messagesAsList, - IncomingDeviceMessage::getDeviceId, - IncomingDeviceMessage::getRegistrationId); + final Set excludedDeviceIds; + + if (isSyncMessage) { + excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId()); + } else { + excludedDeviceIds = Collections.emptySet(); + } + + DestinationDeviceValidator.validateCompleteDeviceList( + destination.get(), + Arrays.stream(messages).map(IncomingDeviceMessage::getDeviceId).collect(Collectors.toSet()), + excludedDeviceIds); + + DestinationDeviceValidator.validateRegistrationIds( + destination.get(), + Arrays.stream(messages).collect(Collectors.toMap( + IncomingDeviceMessage::getDeviceId, + IncomingDeviceMessage::getRegistrationId)), + destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), @@ -372,27 +395,29 @@ public class MessageController { })); checkAccessKeys(accessKeys, uuidToAccountMap); - final Map>> accountToDeviceIdAndRegistrationIdMap = - Arrays - .stream(multiRecipientMessage.getRecipients()) - .collect(Collectors.toMap( - recipient -> uuidToAccountMap.get(recipient.getUuid()), - recipient -> new HashSet<>( - Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), - (a, b) -> { - a.addAll(b); - return a; - } - )); + final Map> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients()) + .collect(Collectors.toMap( + recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()), + (a, b) -> { + final Map combined = new HashMap<>(); + combined.putAll(a); + combined.putAll(b); + + return combined; + } + )); Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); uuidToAccountMap.values().forEach(account -> { - final Set> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account); - final Set deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet()); + final Set deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet(); try { - MessageValidation.validateCompleteDeviceList(account, deviceIds, false, Optional.empty()); - MessageValidation.validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); + DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); + + // Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number + // identity + DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountAttributes.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountAttributes.java index f47ad996b..199dc549f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountAttributes.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountAttributes.java @@ -6,9 +6,11 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; +import javax.annotation.Nullable; import javax.validation.constraints.Size; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.util.ExactlySize; +import java.util.OptionalInt; public class AccountAttributes { @@ -18,6 +20,10 @@ public class AccountAttributes { @JsonProperty private int registrationId; + @Nullable + @JsonProperty("pniRegistrationId") + private Integer phoneNumberIdentityRegistrationId; + @JsonProperty @Size(max = 204, message = "This field must be less than 50 characters") private String name; @@ -59,6 +65,10 @@ public class AccountAttributes { return registrationId; } + public OptionalInt getPhoneNumberIdentityRegistrationId() { + return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty(); + } + public String getName() { return name; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index 5b97e59ec..49970b270 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -5,69 +5,17 @@ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import javax.annotation.Nullable; -import javax.validation.constraints.NotBlank; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; +import javax.validation.constraints.NotBlank; -public class ChangePhoneNumberRequest { - - @JsonProperty - @NotBlank - final String number; - - @JsonProperty - @NotBlank - final String code; - - @JsonProperty("reglock") - @Nullable - final String registrationLock; - - @JsonProperty("device_messages") - @Nullable - final List deviceMessages; - - @JsonProperty("device_signed_prekeys") - @Nullable - final Map deviceSignedPrekeys; - - @JsonCreator - public ChangePhoneNumberRequest(@JsonProperty("number") final String number, - @JsonProperty("code") final String code, - @JsonProperty("reglock") @Nullable final String registrationLock, - @JsonProperty("device_messages") @Nullable final List deviceMessages, - @JsonProperty("device_signed_prekeys") @Nullable final Map deviceSignedPrekeys) { - - this.number = number; - this.code = code; - this.registrationLock = registrationLock; - this.deviceMessages = deviceMessages; - this.deviceSignedPrekeys = deviceSignedPrekeys; - } - - public String getNumber() { - return number; - } - - public String getCode() { - return code; - } - - @Nullable - public String getRegistrationLock() { - return registrationLock; - } - - @Nullable - public List getDeviceMessages() { - return deviceMessages; - } - - @Nullable - public Map getDeviceSignedPrekeys() { - return deviceSignedPrekeys; - } +public record ChangePhoneNumberRequest(@NotBlank String number, + @NotBlank String code, + @JsonProperty("reglock") @Nullable String registrationLock, + @Nullable String pniIdentityKey, + @Nullable List deviceMessages, + @Nullable Map devicePniSignedPrekeys, + @Nullable Map pniRegistrationIds) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java index 0d8bb6747..cae82e240 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java @@ -20,6 +20,7 @@ public class OutgoingMessageEntity { private final UUID sourceUuid; private final int sourceDevice; private final UUID destinationUuid; + private final UUID updatedPni; private final byte[] content; private final long serverTimestamp; @@ -31,6 +32,7 @@ public class OutgoingMessageEntity { @JsonProperty("sourceUuid") final UUID sourceUuid, @JsonProperty("sourceDevice") final int sourceDevice, @JsonProperty("destinationUuid") final UUID destinationUuid, + @JsonProperty("updatedPni") final UUID updatedPni, @JsonProperty("content") final byte[] content, @JsonProperty("serverTimestamp") final long serverTimestamp) { @@ -41,6 +43,7 @@ public class OutgoingMessageEntity { this.sourceUuid = sourceUuid; this.sourceDevice = sourceDevice; this.destinationUuid = destinationUuid; + this.updatedPni = updatedPni; this.content = content; this.serverTimestamp = serverTimestamp; } @@ -73,6 +76,10 @@ public class OutgoingMessageEntity { return destinationUuid; } + public UUID getUpdatedPni() { + return updatedPni; + } + public byte[] getContent() { return content; } @@ -83,23 +90,21 @@ public class OutgoingMessageEntity { @Override public boolean equals(final Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - final OutgoingMessageEntity that = (OutgoingMessageEntity)o; - return type == that.type && - timestamp == that.timestamp && - sourceDevice == that.sourceDevice && - serverTimestamp == that.serverTimestamp && - guid.equals(that.guid) && - Objects.equals(source, that.source) && - Objects.equals(sourceUuid, that.sourceUuid) && - destinationUuid.equals(that.destinationUuid) && - Arrays.equals(content, that.content); + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + final OutgoingMessageEntity that = (OutgoingMessageEntity) o; + return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice + && serverTimestamp == that.serverTimestamp && guid.equals(that.guid) && Objects.equals(source, that.source) + && Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid) + && Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content); } @Override public int hashCode() { - int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, serverTimestamp); + int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, updatedPni, + serverTimestamp); result = 31 * result + Arrays.hashCode(content); return result; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index ecd58bca0..d285b5173 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -19,7 +19,9 @@ import io.micrometer.core.instrument.Tags; import java.io.IOException; import java.time.Clock; import java.time.Duration; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -28,10 +30,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisOperation; @@ -39,6 +44,7 @@ import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UsernameValidator; import org.whispersystems.textsecuregcm.util.Util; @@ -152,6 +158,7 @@ public class AccountsManager { device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setRegistrationId(accountAttributes.getRegistrationId()); + accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId); device.setName(accountAttributes.getName()); device.setCapabilities(accountAttributes.getCapabilities()); device.setCreated(System.currentTimeMillis()); @@ -220,7 +227,11 @@ public class AccountsManager { } } - public Account changeNumber(final Account account, final String number) throws InterruptedException { + public Account changeNumber(final Account account, final String number, + @Nullable final String pniIdentityKey, + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { + final String originalNumber = account.getNumber(); final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier(); @@ -228,6 +239,22 @@ public class AccountsManager { return account; } + if (pniSignedPreKeys != null && pniRegistrationIds != null) { + // Check that all including master ID are in signed pre-keys + DestinationDeviceValidator.validateCompleteDeviceList( + account, + pniSignedPreKeys.keySet(), + Collections.emptySet()); + + // Check that all devices are accounted for in the map of new PNI registration IDs + DestinationDeviceValidator.validateCompleteDeviceList( + account, + pniRegistrationIds.keySet(), + Collections.emptySet()); + } else if (pniSignedPreKeys != null || pniRegistrationIds != null) { + throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null"); + } + final AtomicReference updatedAccount = new AtomicReference<>(); deletedAccountsManager.lockAndPut(account.getNumber(), number, (originalAci, deletedAci) -> { @@ -252,7 +279,22 @@ public class AccountsManager { try { numberChangedAccount = updateWithRetries( account, - a -> true, + a -> { + //noinspection ConstantConditions + if (pniSignedPreKeys != null && pniRegistrationIds != null) { + pniSignedPreKeys.forEach((deviceId, signedPreKey) -> + a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey))); + + pniRegistrationIds.forEach((deviceId, registrationId) -> + a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId))); + } + + if (pniIdentityKey != null) { + a.setPhoneNumberIdentityKey(pniIdentityKey); + } + + return true; + }, a -> accounts.changeNumber(a, number, phoneNumberIdentifier), () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 984ce8f5b..796da06aa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -6,19 +6,25 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.commons.lang3.ObjectUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; -import javax.validation.constraints.NotNull; -import java.util.List; -import java.util.Map; -import java.util.Optional; public class ChangeNumberManager { private static final Logger logger = LoggerFactory.getLogger(AccountController.class); @@ -32,35 +38,54 @@ public class ChangeNumberManager { this.accountsManager = accountsManager; } - public Account changeNumber( - @NotNull Account account, - @NotNull final String number, - @NotNull final Map deviceSignedPrekeys, - @NotNull final List deviceMessages) throws InterruptedException { + public Account changeNumber(final Account account, final String number, + @Nullable final String pniIdentityKey, + @Nullable final Map deviceSignedPreKeys, + @Nullable final List deviceMessages, + @Nullable final Map pniRegistrationIds) + throws InterruptedException, MismatchedDevicesException, StaleDevicesException { + + if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { + assert pniIdentityKey != null; + assert deviceSignedPreKeys != null; + assert deviceMessages != null; + assert pniRegistrationIds != null; + + // Check that all except master ID are in device messages + DestinationDeviceValidator.validateCompleteDeviceList( + account, + deviceMessages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()), + Set.of(Device.MASTER_ID)); + + DestinationDeviceValidator.validateRegistrationIds( + account, + deviceMessages.stream() + .collect(Collectors.toMap( + IncomingMessage::getDestinationDeviceId, + IncomingMessage::getDestinationRegistrationId)), + false); + } else if (!ObjectUtils.allNull(deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { + throw new IllegalArgumentException("Signed pre-keys, device messages, and registration IDs must be all null or all non-null"); + } final Account updatedAccount; + if (number.equals(account.getNumber())) { // This may be a request that got repeated due to poor network conditions or other client error; take no action, // but report success since the account is in the desired state updatedAccount = account; } else { - updatedAccount = accountsManager.changeNumber(account, number); + updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); } - // Whether the account already has this number or not, we reset signed prekeys and resend messages. - // This makes it so the client can resend a request they didn't get a response for (timeout, etc) - // to make sure their messages sent and prekeys were updated, even if the first time around the - // server crashed at/above this point. - if (deviceSignedPrekeys != null && !deviceSignedPrekeys.isEmpty()) { - for (Map.Entry entry : deviceSignedPrekeys.entrySet()) { - accountsManager.updateDevice(updatedAccount, entry.getKey(), - d -> d.setPhoneNumberIdentitySignedPreKey(entry.getValue())); - } - - for (IncomingMessage message : deviceMessages) { - sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message); - } + // Whether the account already has this number or not, we resend messages. This makes it so the client can resend a + // request they didn't get a response for (timeout, etc) to make sure their messages sent even if the first time + // around the server crashed at/above this point. + if (deviceMessages != null) { + deviceMessages.forEach(message -> + sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message)); } + return updatedAccount; } @@ -86,6 +111,7 @@ public class ChangeNumberManager { .setSource(sourceAndDestinationAccount.getNumber()) .setSourceUuid(sourceAndDestinationAccount.getUuid().toString()) .setSourceDevice((int) Device.MASTER_ID) + .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) .build(); messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false); } catch (NotPushRegisteredException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 4a0d9b3fd..00dc189ba 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.OptionalInt; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; @@ -49,6 +50,10 @@ public class Device { @JsonProperty private int registrationId; + @Nullable + @JsonProperty("pniRegistrationId") + private Integer phoneNumberIdentityRegistrationId; + @JsonProperty private SignedPreKey signedPreKey; @@ -184,6 +189,14 @@ public class Device { this.registrationId = registrationId; } + public OptionalInt getPhoneNumberIdentityRegistrationId() { + return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty(); + } + + public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) { + this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; + } + public SignedPreKey getSignedPreKey() { return signedPreKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 3157e4fb9..7f4cd7aa0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -389,6 +389,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, envelope.getSourceDevice(), envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null, + envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.hasContent() ? envelope.getContent().toByteArray() : null, envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index ee96eea58..1f214c401 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -46,6 +46,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { private static final String KEY_SOURCE_UUID = "SU"; private static final String KEY_SOURCE_DEVICE = "SD"; private static final String KEY_DESTINATION_UUID = "DU"; + private static final String KEY_UPDATED_PNI = "UP"; private static final String KEY_CONTENT = "C"; private static final String KEY_TTL = "E"; @@ -85,10 +86,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .put(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT, convertLocalIndexMessageUuidSortKey(messageUuid)) .put(KEY_TYPE, AttributeValues.fromInt(message.getType().getNumber())) .put(KEY_TIMESTAMP, AttributeValues.fromLong(message.getTimestamp())) - .put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message))); - - item.put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid()))); + .put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message))) + .put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid()))); + if (message.hasUpdatedPni()) { + item.put(KEY_UPDATED_PNI, AttributeValues.fromUUID(UUID.fromString(message.getUpdatedPni()))); + } if (message.hasSource()) { item.put(KEY_SOURCE, AttributeValues.fromString(message.getSource())); } @@ -240,7 +243,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0); final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null); final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null); - return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, content, sortKey.getServerTimestamp()); + final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null); + return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, + updatedPni, content, sortKey.getServerTimestamp()); } private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java new file mode 100644 index 000000000..3e44a8f7c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java @@ -0,0 +1,95 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; + +public class DestinationDeviceValidator { + + /** + * Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID + * pairs in the given destination account. This method does not validate that all devices associated with the + * destination account are present in the given device ID/registration ID pairs. + * + * @param account the destination account against which to check the given device ID/registration ID pairs + * @param registrationIdsByDeviceId a map of device IDs to registration IDs + * @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs + * associated with the account's PNI (if available); compare against the ACI-associated + * registration ID otherwise + * + * @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination + * account does not have a corresponding device or if the registration IDs do not match + */ + public static void validateRegistrationIds(final Account account, + final Map registrationIdsByDeviceId, + final boolean usePhoneNumberIdentity) throws StaleDevicesException { + + final List staleDevices = new ArrayList<>(); + + registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { + if (registrationId > 0) { + final boolean registrationIdMatches = + account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ? + device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : + device.getRegistrationId())) + .orElse(false); + + if (!registrationIdMatches) { + staleDevices.add(deviceId); + } + } + }); + + if (!staleDevices.isEmpty()) { + throw new StaleDevicesException(staleDevices); + } + } + + /** + * Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the + * given destination account in preparation for sending those messages to the destination account. In general, the set + * of device IDs must exactly match the set of active devices associated with the destination account. When sending a + * "sync," message, though, the authenticated account is sending messages from one of their devices to all other + * devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}. + * + * @param account the destination account against which to check the given set of device IDs + * @param messageDeviceIds the set of device IDs to check against the destination account + * @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be + * present in the given set of device IDs (i.e. the device that is sending a sync message) + * + * @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with + * the destination account or is missing entries associated with the destination + * account + */ + public static void validateCompleteDeviceList(final Account account, + final Set messageDeviceIds, + final Set excludedDeviceIds) throws MismatchedDevicesException { + + final Set accountDeviceIds = account.getDevices().stream() + .filter(Device::isEnabled) + .map(Device::getId) + .filter(deviceId -> !excludedDeviceIds.contains(deviceId)) + .collect(Collectors.toSet()); + + final Set missingDeviceIds = new HashSet<>(accountDeviceIds); + missingDeviceIds.removeAll(messageDeviceIds); + + final Set extraDeviceIds = new HashSet<>(messageDeviceIds); + extraDeviceIds.removeAll(accountDeviceIds); + + if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { + throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds)); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java deleted file mode 100644 index 5e16bd642..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright 2013-2022 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.util; - -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; -import java.util.Collection; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class MessageValidation { - public static void validateRegistrationIds(Account account, List messages, Function getDeviceId, Function getRegistrationId) - throws StaleDevicesException { - final Stream> deviceIdAndRegistrationIdStream = messages - .stream() - .map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message))); - validateRegistrationIds(account, deviceIdAndRegistrationIdStream); - } - - public static void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) - throws StaleDevicesException { - final List staleDevices = deviceIdAndRegistrationIdStream - .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) - .filter(deviceIdAndRegistrationId -> { - Optional device = account.getDevice(deviceIdAndRegistrationId.first()); - return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId(); - }) - .map(Pair::first) - .collect(Collectors.toList()); - - if (!staleDevices.isEmpty()) { - throw new StaleDevicesException(staleDevices); - } - } - - public static void validateCompleteDeviceList(Account account, Collection messages, Function getDeviceId, boolean isSyncMessage, - Optional authenticatedDeviceId) - throws MismatchedDevicesException { - Set messageDeviceIds = messages.stream().map(getDeviceId) - .collect(Collectors.toSet()); - validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId); - } - - public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage, - Optional authenticatedDeviceId) - throws MismatchedDevicesException { - Set accountDeviceIds = new HashSet<>(); - - List missingDeviceIds = new LinkedList<>(); - List extraDeviceIds = new LinkedList<>(); - - for (Device device : account.getDevices()) { - if (device.isEnabled() && - !(isSyncMessage && device.getId() == authenticatedDeviceId.get())) { - accountDeviceIds.add(device.getId()); - - if (!messageDeviceIds.contains(device.getId())) { - missingDeviceIds.add(device.getId()); - } - } - } - - for (Long deviceId : messageDeviceIds) { - if (!accountDeviceIds.contains(deviceId)) { - extraDeviceIds.add(deviceId); - } - } - - if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { - throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds); - } - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b31a2bd8b..afe7bc1fd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -333,8 +333,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac builder.setDestinationUuid(message.getDestinationUuid().toString()); + if (message.getUpdatedPni() != null) { + builder.setUpdatedPni(message.getUpdatedPni().toString()); + } + builder.setServerGuid(message.getGuid().toString()); + final Envelope envelope = builder.build(); if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { diff --git a/service/src/main/proto/TextSecure.proto b/service/src/main/proto/TextSecure.proto index 0897570ff..0f65bca27 100644 --- a/service/src/main/proto/TextSecure.proto +++ b/service/src/main/proto/TextSecure.proto @@ -41,6 +41,8 @@ message Envelope { optional uint64 server_timestamp = 10; optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline optional string destination_uuid = 13; + optional string updated_pni = 15; + // next: 16 } message ProvisioningUuid { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 0df7b654a..70ebfdbc3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -27,7 +27,6 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import com.google.common.collect.ImmutableSet; -import com.vdurmont.semver4j.Semver; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; @@ -38,7 +37,6 @@ import java.util.Base64; import java.util.LinkedList; import java.util.List; import java.util.Optional; -import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -92,6 +90,7 @@ class MessageControllerTest { private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); + private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); @@ -127,31 +126,33 @@ class MessageControllerTest { @BeforeEach void setup() { final List singleDeviceList = List.of( - generateTestDevice(1, 111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) + generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) ); final List multiDeviceList = List.of( - generateTestDevice(1, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(2, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(3, 444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) + generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) ); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes()); - Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, UUID.randomUUID(), multiDeviceList, "1234".getBytes()); + Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes()); internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes()); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByAccountIdentifier(eq(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); + when(accountsManager.getByPhoneNumberIdentifier(MULTI_DEVICE_PNI)).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount)); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); } - private static Device generateTestDevice(final long id, final int registrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { + private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { final Device device = new Device(); device.setId(id); device.setRegistrationId(registrationId); + device.setPhoneNumberIdentityRegistrationId(pniRegistrationId); device.setSignedPreKey(signedPreKey); device.setCreated(createdAt); device.setLastSeen(lastSeen); @@ -197,6 +198,28 @@ class MessageControllerTest { } } + private static Stream> currentMessageSingleDevicePayloadsPni() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(1); // count + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0x04, (byte)0x57 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + @ParameterizedTest @MethodSource("currentMessageSingleDevicePayloads") void testSendFromDisabledAccount(Entity payload) throws Exception { @@ -230,7 +253,7 @@ class MessageControllerTest { } @ParameterizedTest - @MethodSource("currentMessageSingleDevicePayloads") + @MethodSource("currentMessageSingleDevicePayloadsPni") void testSingleDeviceCurrentByPni(Entity payload) throws Exception { Response response = resources.getJerseyTest() @@ -403,6 +426,50 @@ class MessageControllerTest { } } + @ParameterizedTest + @MethodSource + void testMultiDeviceByPni(Entity payload) throws Exception { + Response response = + resources.getJerseyTest() + .target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(payload); + + assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); + + verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); + } + + private static Stream> testMultiDeviceByPni() { + ByteArrayOutputStream messageStream = new ByteArrayOutputStream(); + messageStream.write(1); // version + messageStream.write(2); // count + + messageStream.write(1); // device ID + messageStream.writeBytes(new byte[] { (byte)0x08, (byte)0xae }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + messageStream.write(2); // device ID + messageStream.writeBytes(new byte[] { (byte)0x0d, (byte)0x05 }); // registration ID + messageStream.write(1); // message type + messageStream.write(3); // message length + messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents + + try { + return Stream.of( + Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"), + IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE), + Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE) + ); + } catch (Exception e) { + throw new AssertionError(e); + } + } + @ParameterizedTest @MethodSource void testRegistrationIdMismatch(Entity payload) throws Exception { @@ -459,9 +526,11 @@ class MessageControllerTest { final UUID messageGuidOne = UUID.randomUUID(); final UUID sourceUuid = UUID.randomUUID(); + final UUID updatedPniOne = UUID.randomUUID(); + List messages = new LinkedList<>() {{ - add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, "hi there".getBytes(), 0)); - add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, 0)); + add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0)); + add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0)); }}; OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); @@ -485,16 +554,19 @@ class MessageControllerTest { assertEquals(response.getMessages().get(0).getSourceUuid(), sourceUuid); assertEquals(response.getMessages().get(1).getSourceUuid(), sourceUuid); + + assertEquals(updatedPniOne, response.getMessages().get(0).getUpdatedPni()); + assertNull(response.getMessages().get(1).getUpdatedPni()); } @Test - void testGetMessagesBadAuth() throws Exception { + void testGetMessagesBadAuth() { final long timestampOne = 313377; final long timestampTwo = 313388; - List messages = new LinkedList() {{ - add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, "hi there".getBytes(), 0)); - add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, 0)); + List messages = new LinkedList<>() {{ + add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0)); + add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)); }}; OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); @@ -520,12 +592,12 @@ class MessageControllerTest { UUID uuid1 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(new OutgoingMessageEntity( uuid1, Envelope.Type.CIPHERTEXT_VALUE, - timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, "hi".getBytes(), 0))); + timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))); UUID uuid2 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(new OutgoingMessageEntity( uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, - System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, 0))); + System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))); UUID uuid3 = UUID.randomUUID(); when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null)).thenReturn(Optional.empty()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 155b1394e..ff4c01ee5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -15,14 +15,18 @@ import static org.mockito.Mockito.when; import java.time.Clock; import java.util.ArrayList; +import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; @@ -198,7 +202,7 @@ class AccountsManagerChangeNumberIntegrationTest { } @Test - void testChangeNumber() throws InterruptedException { + void testChangeNumber() throws InterruptedException, MismatchedDevicesException { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; @@ -206,7 +210,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - accountsManager.changeNumber(account, secondNumber); + accountsManager.changeNumber(account, secondNumber, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -221,7 +225,46 @@ class AccountsManagerChangeNumberIntegrationTest { } @Test - void testChangeNumberReturnToOriginal() throws InterruptedException { + void testChangeNumberWithPniExtensions() throws InterruptedException, MismatchedDevicesException { + final String originalNumber = "+18005551111"; + final String secondNumber = "+18005552222"; + final int rotatedPniRegistrationId = 17; + final SignedPreKey rotatedSignedPreKey = new SignedPreKey(1, "test", "test"); + + final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities()); + final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>()); + account.getMasterDevice().orElseThrow().setSignedPreKey(new SignedPreKey()); + + final UUID originalUuid = account.getUuid(); + final UUID originalPni = account.getPhoneNumberIdentifier(); + + final String pniIdentityKey = "changed-pni-identity-key"; + final Map preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); + final Map registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); + + final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, registrationIds); + + assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); + + assertTrue(accountsManager.getByE164(secondNumber).isPresent()); + assertEquals(originalUuid, accountsManager.getByE164(secondNumber).map(Account::getUuid).orElseThrow()); + assertNotEquals(originalPni, accountsManager.getByE164(secondNumber).map(Account::getPhoneNumberIdentifier).orElseThrow()); + + assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow()); + + assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); + assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); + + assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); + + assertEquals(OptionalInt.of(rotatedPniRegistrationId), + updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentityRegistrationId()); + + assertEquals(rotatedSignedPreKey, updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentitySignedPreKey()); + } + + @Test + void testChangeNumberReturnToOriginal() throws InterruptedException, MismatchedDevicesException { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; @@ -229,8 +272,8 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - account = accountsManager.changeNumber(account, secondNumber); - accountsManager.changeNumber(account, originalNumber); + account = accountsManager.changeNumber(account, secondNumber, null, null, null); + accountsManager.changeNumber(account, originalNumber, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isPresent()); assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow()); @@ -245,7 +288,7 @@ class AccountsManagerChangeNumberIntegrationTest { } @Test - void testChangeNumberContested() throws InterruptedException { + void testChangeNumberContested() throws InterruptedException, MismatchedDevicesException { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; @@ -255,7 +298,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final UUID existingAccountUuid = existingAccount.getUuid(); - accountsManager.changeNumber(account, secondNumber); + accountsManager.changeNumber(account, secondNumber, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -269,7 +312,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); - accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber); + accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null); final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); @@ -278,7 +321,7 @@ class AccountsManagerChangeNumberIntegrationTest { } @Test - void testChangeNumberChaining() throws InterruptedException { + void testChangeNumberChaining() throws InterruptedException, MismatchedDevicesException { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; @@ -289,7 +332,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final UUID existingAccountUuid = existingAccount.getUuid(); - final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber); + final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); @@ -300,7 +343,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); - final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber); + final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null); assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index e805b3399..76b63047b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -45,6 +45,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; @@ -641,7 +642,7 @@ class AccountsManagerTest { } @Test - void testChangePhoneNumber() throws InterruptedException { + void testChangePhoneNumber() throws InterruptedException, MismatchedDevicesException { doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty())) .when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any()); @@ -651,7 +652,7 @@ class AccountsManagerTest { final UUID originalPni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, targetNumber); + account = accountsManager.changeNumber(account, targetNumber, null, null, null); assertEquals(targetNumber, account.getNumber()); @@ -663,11 +664,11 @@ class AccountsManagerTest { } @Test - void testChangePhoneNumberSameNumber() throws InterruptedException { + void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDevicesException { final String number = "+14152222222"; Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, number); + account = accountsManager.changeNumber(account, number, null, null, null); assertEquals(number, account.getNumber()); verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any()); @@ -676,7 +677,7 @@ class AccountsManagerTest { } @Test - void testChangePhoneNumberExistingAccount() throws InterruptedException { + void testChangePhoneNumberExistingAccount() throws InterruptedException, MismatchedDevicesException { doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty())) .when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any()); @@ -691,7 +692,7 @@ class AccountsManagerTest { when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, targetNumber); + account = accountsManager.changeNumber(account, targetNumber, null, null, null); assertEquals(targetNumber, account.getNumber()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 5e1923e7e..2899909c2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -4,6 +4,8 @@ */ package org.whispersystems.textsecuregcm.storage; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -11,7 +13,9 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -19,31 +23,43 @@ import java.util.UUID; import org.apache.commons.codec.binary.Base64; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.Mockito; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; public class ChangeNumberManagerTest { - private static AccountsManager accountsManager = mock(AccountsManager.class); - private static MessageSender messageSender = mock(MessageSender.class); - private ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); + private AccountsManager accountsManager; + private MessageSender messageSender; + private ChangeNumberManager changeNumberManager; + + private Map updatedPhoneNumberIdentifiersByAccount; @BeforeEach - void reset() throws Exception { - Mockito.reset(accountsManager, messageSender); - when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer) invocation -> { + void setUp() throws Exception { + accountsManager = mock(AccountsManager.class); + messageSender = mock(MessageSender.class); + changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); + + updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); + + when(accountsManager.changeNumber(any(), any(), any(), any(), any())).thenAnswer((Answer)invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); final UUID uuid = account.getUuid(); final List devices = account.getDevices(); + final UUID updatedPni = UUID.randomUUID(); + updatedPhoneNumberIdentifiersByAccount.put(account, updatedPni); + final Account updatedAccount = mock(Account.class); when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getNumber()).thenReturn(number); - when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni); when(updatedAccount.getDevices()).thenReturn(devices); for (long i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); @@ -58,8 +74,8 @@ public class ChangeNumberManagerTest { void changeNumberNoMessages() throws Exception { Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); - changeNumberManager.changeNumber(account, "+18025551234", Collections.EMPTY_MAP, Collections.EMPTY_LIST); - verify(accountsManager).changeNumber(account, "+18025551234"); + changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null); + verify(accountsManager).changeNumber(account, "+18025551234", null, null, null); verify(accountsManager, never()).updateDevice(any(), eq(1L), any()); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); } @@ -69,27 +85,112 @@ public class ChangeNumberManagerTest { Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); var prekeys = Map.of(1L, new SignedPreKey()); - changeNumberManager.changeNumber(account, "+18025551234", prekeys, Collections.EMPTY_LIST); - verify(accountsManager).changeNumber(account, "+18025551234"); - verify(accountsManager).updateDevice(any(), eq(1L), any()); + final String pniIdentityKey = "pni-identity-key"; + + changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyList(), Collections.emptyMap()); + verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyMap()); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); } @Test void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { - Account account = mock(Account.class); - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - Device d2 = mock(Device.class); + final String originalE164 = "+18005551234"; + final String changedE164 = "+18025551234"; + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn(originalE164); + when(account.getUuid()).thenReturn(aci); + when(account.getPhoneNumberIdentifier()).thenReturn(pni); + + final Device d2 = mock(Device.class); + when(d2.isEnabled()).thenReturn(true); + when(d2.getId()).thenReturn(2L); + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); - var prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); - IncomingMessage msg = mock(IncomingMessage.class); + when(account.getDevices()).thenReturn(List.of(d2)); + + final String pniIdentityKey = "pni-identity-key"; + final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 19); + + final IncomingMessage msg = mock(IncomingMessage.class); when(msg.getDestinationDeviceId()).thenReturn(2L); when(msg.getContent()).thenReturn(Base64.encodeBase64String(new byte[]{1})); - changeNumberManager.changeNumber(account, "+18025551234", prekeys, List.of(msg)); - verify(accountsManager).changeNumber(account, "+18025551234"); - verify(accountsManager).updateDevice(any(), eq(1L), any()); - verify(accountsManager).updateDevice(any(), eq(2L), any()); - verify(messageSender).sendMessage(any(), eq(d2), any(), eq(false)); + + changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); + + verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, registrationIds); + + final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); + verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + + assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); + assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(changedE164, envelope.getSource()); + assertEquals(Device.MASTER_ID, envelope.getSourceDevice()); + assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); + } + + @Test + void changeNumberMismatchedRegistrationId() { + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + + final List devices = new ArrayList<>(); + + for (int i = 1; i <= 3; i++) { + final Device device = mock(Device.class); + when(device.getId()).thenReturn((long) i); + when(device.isEnabled()).thenReturn(true); + when(device.getRegistrationId()).thenReturn(i); + + devices.add(device); + when(account.getDevice(i)).thenReturn(Optional.of(device)); + } + + when(account.getDevices()).thenReturn(devices); + + final List messages = List.of( + new IncomingMessage(1, null, 2, 1, "foo"), + new IncomingMessage(1, null, 3, 1, "foo")); + + final Map preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + + assertThrows(StaleDevicesException.class, + () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds)); + } + + @Test + void changeNumberMissingData() { + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + + final List devices = new ArrayList<>(); + + for (int i = 1; i <= 3; i++) { + final Device device = mock(Device.class); + when(device.getId()).thenReturn((long) i); + when(device.isEnabled()).thenReturn(true); + when(device.getRegistrationId()).thenReturn(i); + + devices.add(device); + when(account.getDevice(i)).thenReturn(Optional.of(device)); + } + + when(account.getDevices()).thenReturn(devices); + + final List messages = List.of( + new IncomingMessage(1, null, 2, 2, "foo"), + new IncomingMessage(1, null, 3, 3, "foo")); + + final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + + assertThrows(IllegalArgumentException.class, + () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, messages, registrationIds)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index be103dac2..bb39937bf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -253,9 +253,10 @@ class AccountControllerTest { when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername")) .thenThrow(new UsernameNotAvailableException()); - when(changeNumberManager.changeNumber(any(), any(), any(), any())).thenAnswer((Answer) invocation -> { + when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); + final String pniIdentityKey = invocation.getArgument(2, String.class); final UUID uuid = account.getUuid(); final List devices = account.getDevices(); @@ -263,8 +264,10 @@ class AccountControllerTest { final Account updatedAccount = mock(Account.class); when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getNumber()).thenReturn(number); + when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); when(updatedAccount.getDevices()).thenReturn(devices); + for (long i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); when(updatedAccount.getDevice(i)).thenReturn(d); @@ -1298,10 +1301,10 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.getNumber()).isEqualTo(number); @@ -1318,12 +1321,12 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.readEntity(String.class)).isBlank(); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1336,7 +1339,7 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); @@ -1345,7 +1348,7 @@ class AccountControllerTest { assertThat(responseEntity.getOriginalNumber()).isEqualTo(number); assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111"); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1355,10 +1358,10 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); } @Test @@ -1373,11 +1376,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1393,11 +1396,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1423,11 +1426,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); } @Test @@ -1453,11 +1456,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1485,11 +1488,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); } @Test @@ -1517,82 +1520,33 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); - } - - @Test - void testChangePhoneNumberDeviceMessagesWithoutPrekeys() throws Exception { - final String number = "+18005559876"; - final String code = "987654"; - - when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( - new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); - - final Response response = - resources.getJerseyTest() - .target("/v1/accounts/number") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, - List.of(new IncomingMessage(1, null, 1, 1, "foo")), null), - MediaType.APPLICATION_JSON_TYPE)); - - assertThat(response.getStatus()).isEqualTo(400); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); - } - - @Test - void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchDeviceIDs() throws Exception { - final String number = "+18005559876"; - final String code = "987654"; - - Device device2 = mock(Device.class); - when(device2.getId()).thenReturn(2L); - when(device2.isEnabled()).thenReturn(true); - Device device3 = mock(Device.class); - when(device3.getId()).thenReturn(3L); - when(device3.isEnabled()).thenReturn(true); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3)); - when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( - new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); - - final Response response = - resources.getJerseyTest() - .target("/v1/accounts/number") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest( - number, code, null, - List.of( - new IncomingMessage(1, null, 2, 1, "foo"), - new IncomingMessage(1, null, 4, 1, "foo")), - Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())), - MediaType.APPLICATION_JSON_TYPE)); - - assertThat(response.getStatus()).isEqualTo(409); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); } @Test void testChangePhoneNumberChangePrekeys() throws Exception { final String number = "+18005559876"; final String code = "987654"; + final String pniIdentityKey = "changed-pni-identity-key"; Device device2 = mock(Device.class); when(device2.getId()).thenReturn(2L); when(device2.isEnabled()).thenReturn(true); when(device2.getRegistrationId()).thenReturn(2); + Device device3 = mock(Device.class); when(device3.getId()).thenReturn(3L); when(device3.isEnabled()).thenReturn(true); when(device3.getRegistrationId()).thenReturn(3); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3)); when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); + when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); @@ -1601,6 +1555,8 @@ class AccountControllerTest { new IncomingMessage(1, null, 3, 3, "content3")); var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + final AccountIdentityResponse accountIdentityResponse = resources.getJerseyTest() .target("/v1/accounts/number") @@ -1608,53 +1564,18 @@ class AccountControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest( number, code, null, - deviceMessages, - deviceKeys), + pniIdentityKey, deviceMessages, + deviceKeys, + registrationIds), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.getNumber()).isEqualTo(number); assertThat(accountIdentityResponse.getPni()).isNotEqualTo(AuthHelper.VALID_PNI); } - @Test - void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchRegistrationID() throws Exception { - final String number = "+18005559876"; - final String code = "987654"; - - Device device2 = mock(Device.class); - when(device2.getId()).thenReturn(2L); - when(device2.isEnabled()).thenReturn(true); - when(device2.getRegistrationId()).thenReturn(2); - Device device3 = mock(Device.class); - when(device3.getId()).thenReturn(3L); - when(device3.isEnabled()).thenReturn(true); - when(device3.getRegistrationId()).thenReturn(3); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3)); - when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); - when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); - when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( - new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); - - final Response response = - resources.getJerseyTest() - .target("/v1/accounts/number") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest( - number, code, null, - List.of( - new IncomingMessage(1, null, 2, 1, "foo"), - new IncomingMessage(1, null, 3, 1, "foo")), - Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())), - MediaType.APPLICATION_JSON_TYPE)); - - assertThat(response.getStatus()).isEqualTo(410); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); - } - @Test void testSetRegistrationLock() { Response response = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 4d8e7300c..12066b77a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Optional; +import java.util.OptionalInt; import java.util.UUID; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; @@ -73,6 +74,8 @@ class KeysControllerTest { private static final int SAMPLE_REGISTRATION_ID2 = 1002; private static final int SAMPLE_REGISTRATION_ID4 = 1555; + private static final int SAMPLE_PNI_REGISTRATION_ID = 1717; + private final PreKey SAMPLE_KEY = new PreKey(1234, "test1"); private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3"); private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5"); @@ -106,9 +109,11 @@ class KeysControllerTest { .addResource(new RateLimitExceededExceptionMapper()) .build(); + private Device sampleDevice; + @BeforeEach void setup() { - final Device sampleDevice = mock(Device.class); + sampleDevice = mock(Device.class); final Device sampleDevice2 = mock(Device.class); final Device sampleDevice3 = mock(Device.class); final Device sampleDevice4 = mock(Device.class); @@ -121,6 +126,7 @@ class KeysControllerTest { when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4); + when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(SAMPLE_PNI_REGISTRATION_ID)); when(sampleDevice.isEnabled()).thenReturn(true); when(sampleDevice2.isEnabled()).thenReturn(true); when(sampleDevice3.isEnabled()).thenReturn(false); @@ -284,6 +290,7 @@ class KeysControllerTest { assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); verify(KEYS).take(EXISTS_UUID, 1); @@ -302,6 +309,28 @@ class KeysControllerTest { assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); + assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); + + verify(KEYS).take(EXISTS_PNI, 1); + verifyNoMoreInteractions(KEYS); + } + + @Test + void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() { + when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); + + PreKeyResponse result = resources.getJerseyTest() + .target(String.format("/v2/keys/%s/1", EXISTS_PNI)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); + + assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); + assertThat(result.getDevicesCount()).isEqualTo(1); + assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); + assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); verify(KEYS).take(EXISTS_PNI, 1); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java deleted file mode 100644 index b90af7409..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright 2013-2022 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.tests.util; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Stream; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.MessageValidation; -import org.whispersystems.textsecuregcm.util.Pair; - -@ExtendWith(DropwizardExtensionsSupport.class) -class MessageValidationTest { - - static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) { - Account account = mock(Account.class); - if (deviceAndRegistrationIds.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) { - if (!(deviceAndRegistrationIds[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { - throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); - } - Long deviceId = (Long) deviceAndRegistrationIds[i]; - Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; - Device device = mock(Device.class); - when(device.getRegistrationId()).thenReturn(registrationId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - } - return account; - } - - static Collection> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) { - final Collection> result = new HashSet<>(deviceAndRegistrationIds.length); - if (deviceAndRegistrationIds.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) { - if (!(deviceAndRegistrationIds[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { - throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); - } - Long deviceId = (Long) deviceAndRegistrationIds[i]; - Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; - result.add(new Pair<>(deviceId, registrationId)); - } - return result; - } - - static Stream validateRegistrationIdsSource() { - return Stream.of( - arguments( - mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), - deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 1492), - Set.of(1L)), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 42), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 0), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42, 2L, 255), - deviceAndRegistrationIds(1L, 0, 2L, 42), - Set.of(2L)), - arguments( - mockAccountWithDeviceAndRegId(1L, 42, 2L, 256), - deviceAndRegistrationIds(1L, 41, 2L, 257), - Set.of(1L, 2L)) - ); - } - - @ParameterizedTest - @MethodSource("validateRegistrationIdsSource") - void testValidateRegistrationIds( - Account account, - Collection> deviceAndRegistrationIds, - Set expectedStaleDeviceIds) throws Exception { - if (expectedStaleDeviceIds != null) { - Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> { - MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); - }).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds); - } else { - MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); - } - } - - static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) { - Account account = mock(Account.class); - if (deviceIdAndEnabled.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - final List devices = new ArrayList<>(deviceIdAndEnabled.length / 2); - for (int i = 0; i < deviceIdAndEnabled.length; i+=2) { - if (!(deviceIdAndEnabled[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) { - throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1)); - } - Long deviceId = (Long) deviceIdAndEnabled[i]; - Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1]; - Device device = mock(Device.class); - when(device.isEnabled()).thenReturn(enabled); - when(device.getId()).thenReturn(deviceId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - devices.add(device); - } - when(account.getDevices()).thenReturn(devices); - return account; - } - - static Stream validateCompleteDeviceListSource() { - return Stream.of( - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 3L), - null, - null, - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 2L, 3L), - null, - Set.of(2L), - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L), - Set.of(3L), - null, - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 2L), - Set.of(3L), - Set.of(2L), - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L), - Set.of(3L), - Set.of(1L), - true, - 1L - ), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(2L), - Set.of(3L), - Set.of(2L), - true, - 1L - ), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(3L), - null, - null, - true, - 1L - ) - ); - } - - @ParameterizedTest - @MethodSource("validateCompleteDeviceListSource") - void testValidateCompleteDeviceList( - Account account, - Set deviceIds, - Collection expectedMissingDeviceIds, - Collection expectedExtraDeviceIds, - boolean isSyncMessage, - Long authenticatedDeviceId) throws Exception { - if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { - final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, - () -> MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage, - Optional.ofNullable(authenticatedDeviceId))); - if (expectedMissingDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) - .hasSameElementsAs(expectedMissingDeviceIds); - } - if (expectedExtraDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); - } - } else { - MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage, - Optional.ofNullable(authenticatedDeviceId)); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java new file mode 100644 index 000000000..a9eaf1c57 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java @@ -0,0 +1,214 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; + +@ExtendWith(DropwizardExtensionsSupport.class) +class DestinationDeviceValidatorTest { + + static Account mockAccountWithDeviceAndRegId(final Map registrationIdsByDeviceId) { + final Account account = mock(Account.class); + + registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { + final Device device = mock(Device.class); + when(device.getRegistrationId()).thenReturn(registrationId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + }); + + return account; + } + + static Stream validateRegistrationIdsSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF)), + Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + null), + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 42)), + Map.of(1L, 1492), + Set.of(1L)), + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 42)), + Map.of(1L, 42), + null), + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 42)), + Map.of(1L, 0), + null), + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 255)), + Map.of(1L, 0, 2L, 42), + Set.of(2L)), + arguments( + mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 256)), + Map.of(1L, 41, 2L, 257), + Set.of(1L, 2L)) + ); + } + + @ParameterizedTest + @MethodSource("validateRegistrationIdsSource") + void testValidateRegistrationIds( + Account account, + Map registrationIdsByDeviceId, + Set expectedStaleDeviceIds) throws Exception { + if (expectedStaleDeviceIds != null) { + Assertions.assertThat(assertThrows(StaleDevicesException.class, + () -> DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false)).getStaleDevices()) + .hasSameElementsAs(expectedStaleDeviceIds); + } else { + DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false); + } + } + + static Account mockAccountWithDeviceAndEnabled(final Map enabledStateByDeviceId) { + final Account account = mock(Account.class); + final List devices = new ArrayList<>(); + + enabledStateByDeviceId.forEach((deviceId, enabled) -> { + final Device device = mock(Device.class); + when(device.isEnabled()).thenReturn(enabled); + when(device.getId()).thenReturn(deviceId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + + devices.add(device); + }); + + when(account.getDevices()).thenReturn(devices); + + return account; + } + + static Stream validateCompleteDeviceListSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(1L, 3L), + null, + null, + Collections.emptySet()), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(1L, 2L, 3L), + null, + Set.of(2L), + Collections.emptySet()), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(1L), + Set.of(3L), + null, + Collections.emptySet()), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(1L, 2L), + Set.of(3L), + Set.of(2L), + Collections.emptySet()), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(1L), + Set.of(3L), + Set.of(1L), + Set.of(1L) + ), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(2L), + Set.of(3L), + Set.of(2L), + Set.of(1L) + ), + arguments( + mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), + Set.of(3L), + null, + null, + Set.of(1L) + ) + ); + } + + @ParameterizedTest + @MethodSource("validateCompleteDeviceListSource") + void testValidateCompleteDeviceList( + Account account, + Set deviceIds, + Collection expectedMissingDeviceIds, + Collection expectedExtraDeviceIds, + Set excludedDeviceIds) throws Exception { + + if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { + final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, + () -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds)); + if (expectedMissingDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) + .hasSameElementsAs(expectedMissingDeviceIds); + } + if (expectedExtraDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); + } + } else { + DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds); + } + } + + @Test + void testValidatePniRegistrationIds() { + final Device device = mock(Device.class); + when(device.getId()).thenReturn(Device.MASTER_ID); + + final Account account = mock(Account.class); + when(account.getDevices()).thenReturn(List.of(device)); + when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device)); + + final int aciRegistrationId = 17; + final int pniRegistrationId = 89; + final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId; + + when(device.getRegistrationId()).thenReturn(aciRegistrationId); + when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId)); + + assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); + assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), false)); + + when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); + + assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false)); + assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), true)); + assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), false)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index d1a37ea3d..9a9455607 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -64,6 +64,7 @@ import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; +import javax.annotation.Nullable; class WebSocketConnectionTest { @@ -285,6 +286,7 @@ class WebSocketConnectionTest { .setSource("sender1") .setSourceUuid(UUID.randomUUID().toString()) .setDestinationUuid(UUID.randomUUID().toString()) + .setUpdatedPni(UUID.randomUUID().toString()) .setTimestamp(System.currentTimeMillis()) .setSourceDevice(1) .setType(Envelope.Type.CIPHERTEXT) @@ -302,12 +304,12 @@ class WebSocketConnectionTest { List pendingMessages = new LinkedList() {{ add(new OutgoingMessageEntity(UUID.randomUUID(), firstMessage.getType().getNumber(), firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()), - firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), - firstMessage.getContent().toByteArray(), 0)); + firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), UUID.fromString(firstMessage.getUpdatedPni()), + firstMessage.getContent().toByteArray(), 0)); add(new OutgoingMessageEntity(UUID.randomUUID(), secondMessage.getType().getNumber(), secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()), - secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), - secondMessage.getContent().toByteArray(), 0)); + secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), null, + secondMessage.getContent().toByteArray(), 0)); }}; OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false); @@ -884,7 +886,7 @@ class WebSocketConnectionTest { private OutgoingMessageEntity createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(UUID.randomUUID(), receipt ? Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, - timestamp, sender, senderUuid, 1, destinationUuid, content.getBytes(), 0); + timestamp, sender, senderUuid, 1, destinationUuid, null, content.getBytes(), 0); } } diff --git a/service/src/test/resources/fixtures/current_message_multi_device_pni.json b/service/src/test/resources/fixtures/current_message_multi_device_pni.json new file mode 100644 index 000000000..0be6b7832 --- /dev/null +++ b/service/src/test/resources/fixtures/current_message_multi_device_pni.json @@ -0,0 +1,16 @@ +{ + "messages" : [{ + "type" : 1, + "destinationDeviceId" : 1, + "destinationRegistrationId" : 2222, + "content" : "Zm9vYmFyego", + "timestamp" : 1234 + }, + { + "type" : 1, + "destinationDeviceId" : 2, + "destinationRegistrationId" : 3333, + "content" : "Zm9vYmFyego", + "timestamp" : 1234 + }] +}