diff --git a/abusive-message-filter b/abusive-message-filter index 8f690df72..ae98ea5c6 160000 --- a/abusive-message-filter +++ b/abusive-message-filter @@ -1 +1 @@ -Subproject commit 8f690df72ccda8fcbf048eee4c07f3e60d52f1fd +Subproject commit ae98ea5c61257e76c98dc4db9e5c2911facb5849 diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2830c5ff2..18f122112 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -169,6 +169,7 @@ import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerListener; import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; import org.whispersystems.textsecuregcm.storage.ContactDiscoveryWriter; import org.whispersystems.textsecuregcm.storage.DeletedAccounts; import org.whispersystems.textsecuregcm.storage.DeletedAccountsDirectoryReconciler; @@ -496,6 +497,7 @@ public class WhisperServerService extends Application directoryReconciliationAccountDatabaseCrawlerListeners = new ArrayList<>(); final List deletedAccountsDirectoryReconcilers = new ArrayList<>(); @@ -622,8 +624,8 @@ public class WhisperServerService extends Application commonControllers = Lists.newArrayList( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index fb56bc385..77c6a8ebd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -33,6 +33,7 @@ import javax.validation.constraints.NotNull; import javax.ws.rs.BadRequestException; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; +import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.HEAD; import javax.ws.rs.HeaderParam; @@ -69,8 +70,11 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest; import org.whispersystems.textsecuregcm.entities.DeviceName; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; +import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.APNSender; @@ -84,6 +88,7 @@ import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; @@ -92,6 +97,7 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; import org.whispersystems.textsecuregcm.util.Hex; import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException; +import org.whispersystems.textsecuregcm.util.MessageValidation; import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException; import org.whispersystems.textsecuregcm.util.Username; import org.whispersystems.textsecuregcm.util.Util; @@ -138,6 +144,7 @@ public class AccountController { private final ExternalServiceCredentialGenerator backupServiceCredentialGenerator; private final TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager; + private final ChangeNumberManager changeNumberManager; public AccountController(StoredVerificationCodeManager pendingAccounts, AccountsManager accounts, @@ -150,8 +157,9 @@ public class AccountController { RecaptchaClient recaptchaClient, GCMSender gcmSender, APNSender apnSender, - ExternalServiceCredentialGenerator backupServiceCredentialGenerator, - TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager) + TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager, + ChangeNumberManager changeNumberManager, + ExternalServiceCredentialGenerator backupServiceCredentialGenerator) { this.pendingAccounts = pendingAccounts; this.accounts = accounts; @@ -164,8 +172,9 @@ public class AccountController { this.recaptchaClient = recaptchaClient; this.gcmSender = gcmSender; this.apnSender = apnSender; - this.backupServiceCredentialGenerator = backupServiceCredentialGenerator; this.verifyExperimentEnrollmentManager = verifyExperimentEnrollmentManager; + this.backupServiceCredentialGenerator = backupServiceCredentialGenerator; + this.changeNumberManager = changeNumberManager; } @Timed @@ -403,38 +412,75 @@ public class AccountController { public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount, @NotNull @Valid final ChangePhoneNumberRequest request) throws RateLimitExceededException, InterruptedException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException { - final Account updatedAccount; + if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) { + throw new ForbiddenException(); + } - if (request.getNumber().equals(authenticatedAccount.getAccount().getNumber())) { - // This may be a request that got repeated due to poor network conditions or other client error; take no action, - // but report success since the account is in the desired state - updatedAccount = authenticatedAccount.getAccount(); - } else { - Util.requireNormalizedNumber(request.getNumber()); + if (request.getDeviceSignedPrekeys() != null && !request.getDeviceSignedPrekeys().isEmpty()) { + if (request.getDeviceMessages() == null || request.getDeviceMessages().size() != request.getDeviceSignedPrekeys().size() - 1) { + // device_messages should exist and be one shorter than device_signed_prekeys, since it doesn't have the primary's key. + throw new WebApplicationException(Response.status(400).build()); + } + try { + // Checks that all except master ID are in device messages + MessageValidation.validateCompleteDeviceList( + authenticatedAccount.getAccount(), request.getDeviceMessages(), + IncomingMessage::getDestinationDeviceId, true, Optional.of(Device.MASTER_ID)); + MessageValidation.validateRegistrationIds( + authenticatedAccount.getAccount(), request.getDeviceMessages(), + IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId); + // Checks that all including master ID are in signed prekeys + MessageValidation.validateCompleteDeviceList( + authenticatedAccount.getAccount(), request.getDeviceSignedPrekeys().entrySet(), + e -> e.getKey(), false, Optional.empty()); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), + e.getExtraDevices())) + .build()); + } catch (StaleDevicesException e) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevices(e.getStaleDevices())) + .build()); + } + } else if (request.getDeviceMessages() != null && !request.getDeviceMessages().isEmpty()) { + // device_messages shouldn't exist without device_signed_prekeys. + throw new WebApplicationException(Response.status(400).build()); + } - rateLimiters.getVerifyLimiter().validate(request.getNumber()); + final String number = request.getNumber(); + if (!authenticatedAccount.getAccount().getNumber().equals(number)) { + Util.requireNormalizedNumber(number); + + rateLimiters.getVerifyLimiter().validate(number); final Optional storedVerificationCode = - pendingAccounts.getCodeForNumber(request.getNumber()); + pendingAccounts.getCodeForNumber(number); if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.getCode())) { - throw new WebApplicationException(Response.status(403).build()); + throw new ForbiddenException(); } storedVerificationCode.flatMap(StoredVerificationCode::getTwilioVerificationSid) .ifPresent(smsSender::reportVerificationSucceeded); - final Optional existingAccount = accounts.getByE164(request.getNumber()); + final Optional existingAccount = accounts.getByE164(number); if (existingAccount.isPresent()) { verifyRegistrationLock(existingAccount.get(), request.getRegistrationLock()); } - rateLimiters.getVerifyLimiter().clear(request.getNumber()); - - updatedAccount = accounts.changeNumber(authenticatedAccount.getAccount(), request.getNumber()); + rateLimiters.getVerifyLimiter().clear(number); } + final Account updatedAccount = changeNumberManager.changeNumber( + authenticatedAccount.getAccount(), + request.getNumber(), + Optional.ofNullable(request.getDeviceSignedPrekeys()).orElse(Collections.emptyMap()), + Optional.ofNullable(request.getDeviceMessages()).orElse(Collections.emptyList())); + return new AccountIdentityResponse( updatedAccount.getUuid(), updatedAccount.getNumber(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index b43d899d5..968042d53 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -92,6 +92,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; +import org.whispersystems.textsecuregcm.util.MessageValidation; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; @@ -225,10 +226,10 @@ public class MessageController { checkRateLimit(source.get(), destination.get(), userAgent); } - validateCompleteDeviceList(destination.get(), messages.getMessages(), + MessageValidation.validateCompleteDeviceList(destination.get(), messages.getMessages(), IncomingMessage::getDestinationDeviceId, isSyncMessage, source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); - validateRegistrationIds(destination.get(), messages.getMessages(), + MessageValidation.validateRegistrationIds(destination.get(), messages.getMessages(), IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), @@ -319,10 +320,10 @@ public class MessageController { } final List messagesAsList = Arrays.asList(messages); - validateCompleteDeviceList(destination.get(), messagesAsList, + MessageValidation.validateCompleteDeviceList(destination.get(), messagesAsList, IncomingDeviceMessage::getDeviceId, isSyncMessage, source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); - validateRegistrationIds(destination.get(), messagesAsList, + MessageValidation.validateRegistrationIds(destination.get(), messagesAsList, IncomingDeviceMessage::getDeviceId, IncomingDeviceMessage::getRegistrationId); @@ -402,8 +403,8 @@ public class MessageController { final Set> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account); final Set deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet()); try { - validateCompleteDeviceList(account, deviceIds, false, Optional.empty()); - validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); + MessageValidation.validateCompleteDeviceList(account, deviceIds, false, Optional.empty()); + MessageValidation.validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); @@ -731,72 +732,6 @@ public class MessageController { } } - @VisibleForTesting - public static void validateRegistrationIds(Account account, List messages, Function getDeviceId, Function getRegistrationId) - throws StaleDevicesException { - final Stream> deviceIdAndRegistrationIdStream = messages - .stream() - .map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message))); - validateRegistrationIds(account, deviceIdAndRegistrationIdStream); - } - - @VisibleForTesting - public static void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) - throws StaleDevicesException { - final List staleDevices = deviceIdAndRegistrationIdStream - .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) - .filter(deviceIdAndRegistrationId -> { - Optional device = account.getDevice(deviceIdAndRegistrationId.first()); - return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId(); - }) - .map(Pair::first) - .collect(Collectors.toList()); - - if (!staleDevices.isEmpty()) { - throw new StaleDevicesException(staleDevices); - } - } - - @VisibleForTesting - public static void validateCompleteDeviceList(Account account, List messages, Function getDeviceId, boolean isSyncMessage, - Optional authenticatedDeviceId) - throws MismatchedDevicesException { - Set messageDeviceIds = messages.stream().map(getDeviceId) - .collect(Collectors.toSet()); - validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId); - } - - @VisibleForTesting - public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage, - Optional authenticatedDeviceId) - throws MismatchedDevicesException { - Set accountDeviceIds = new HashSet<>(); - - List missingDeviceIds = new LinkedList<>(); - List extraDeviceIds = new LinkedList<>(); - - for (Device device : account.getDevices()) { - if (device.isEnabled() && - !(isSyncMessage && device.getId() == authenticatedDeviceId.get())) { - accountDeviceIds.add(device.getId()); - - if (!messageDeviceIds.contains(device.getId())) { - missingDeviceIds.add(device.getId()); - } - } - } - - for (Long deviceId : messageDeviceIds) { - if (!accountDeviceIds.contains(deviceId)) { - extraDeviceIds.add(deviceId); - } - } - - if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { - throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds); - } - } - private void validateContentLength(final int contentLength, final String userAgent) { Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) .record(contentLength); @@ -818,7 +753,7 @@ public class MessageController { } } - private Optional getMessageContent(IncomingMessage message) { + public static Optional getMessageContent(IncomingMessage message) { if (Util.isEmpty(message.getContent())) return Optional.empty(); try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index 088a4d554..5b97e59ec 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -9,6 +9,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import javax.annotation.Nullable; import javax.validation.constraints.NotBlank; +import java.util.List; +import java.util.Map; public class ChangePhoneNumberRequest { @@ -24,14 +26,26 @@ public class ChangePhoneNumberRequest { @Nullable final String registrationLock; + @JsonProperty("device_messages") + @Nullable + final List deviceMessages; + + @JsonProperty("device_signed_prekeys") + @Nullable + final Map deviceSignedPrekeys; + @JsonCreator public ChangePhoneNumberRequest(@JsonProperty("number") final String number, @JsonProperty("code") final String code, - @JsonProperty("reglock") @Nullable final String registrationLock) { + @JsonProperty("reglock") @Nullable final String registrationLock, + @JsonProperty("device_messages") @Nullable final List deviceMessages, + @JsonProperty("device_signed_prekeys") @Nullable final Map deviceSignedPrekeys) { this.number = number; this.code = code; this.registrationLock = registrationLock; + this.deviceMessages = deviceMessages; + this.deviceSignedPrekeys = deviceSignedPrekeys; } public String getNumber() { @@ -46,4 +60,14 @@ public class ChangePhoneNumberRequest { public String getRegistrationLock() { return registrationLock; } + + @Nullable + public List getDeviceMessages() { + return deviceMessages; + } + + @Nullable + public Map getDeviceSignedPrekeys() { + return deviceSignedPrekeys; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java new file mode 100644 index 000000000..984ce8f5b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -0,0 +1,95 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.controllers.AccountController; +import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import javax.validation.constraints.NotNull; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class ChangeNumberManager { + private static final Logger logger = LoggerFactory.getLogger(AccountController.class); + private final MessageSender messageSender; + private final AccountsManager accountsManager; + + public ChangeNumberManager( + final MessageSender messageSender, + final AccountsManager accountsManager) { + this.messageSender = messageSender; + this.accountsManager = accountsManager; + } + + public Account changeNumber( + @NotNull Account account, + @NotNull final String number, + @NotNull final Map deviceSignedPrekeys, + @NotNull final List deviceMessages) throws InterruptedException { + + final Account updatedAccount; + if (number.equals(account.getNumber())) { + // This may be a request that got repeated due to poor network conditions or other client error; take no action, + // but report success since the account is in the desired state + updatedAccount = account; + } else { + updatedAccount = accountsManager.changeNumber(account, number); + } + + // Whether the account already has this number or not, we reset signed prekeys and resend messages. + // This makes it so the client can resend a request they didn't get a response for (timeout, etc) + // to make sure their messages sent and prekeys were updated, even if the first time around the + // server crashed at/above this point. + if (deviceSignedPrekeys != null && !deviceSignedPrekeys.isEmpty()) { + for (Map.Entry entry : deviceSignedPrekeys.entrySet()) { + accountsManager.updateDevice(updatedAccount, entry.getKey(), + d -> d.setPhoneNumberIdentitySignedPreKey(entry.getValue())); + } + + for (IncomingMessage message : deviceMessages) { + sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message); + } + } + return updatedAccount; + } + + @VisibleForTesting + void sendMessageToSelf( + Account sourceAndDestinationAccount, Optional destinationDevice, IncomingMessage message) { + Optional contents = MessageController.getMessageContent(message); + if (!contents.isPresent()) { + logger.debug("empty message contents sending to self, ignoring"); + return; + } else if (!destinationDevice.isPresent()) { + logger.debug("destination device not present"); + return; + } + try { + long serverTimestamp = System.currentTimeMillis(); + Envelope envelope = Envelope.newBuilder() + .setType(Envelope.Type.forNumber(message.getType())) + .setTimestamp(serverTimestamp) + .setServerTimestamp(serverTimestamp) + .setDestinationUuid(sourceAndDestinationAccount.getUuid().toString()) + .setContent(ByteString.copyFrom(contents.get())) + .setSource(sourceAndDestinationAccount.getNumber()) + .setSourceUuid(sourceAndDestinationAccount.getUuid().toString()) + .setSourceDevice((int) Device.MASTER_ID) + .build(); + messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false); + } catch (NotPushRegisteredException e) { + logger.debug("Not registered", e); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java new file mode 100644 index 000000000..5e16bd642 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/MessageValidation.java @@ -0,0 +1,84 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import java.util.Collection; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MessageValidation { + public static void validateRegistrationIds(Account account, List messages, Function getDeviceId, Function getRegistrationId) + throws StaleDevicesException { + final Stream> deviceIdAndRegistrationIdStream = messages + .stream() + .map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message))); + validateRegistrationIds(account, deviceIdAndRegistrationIdStream); + } + + public static void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) + throws StaleDevicesException { + final List staleDevices = deviceIdAndRegistrationIdStream + .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) + .filter(deviceIdAndRegistrationId -> { + Optional device = account.getDevice(deviceIdAndRegistrationId.first()); + return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId(); + }) + .map(Pair::first) + .collect(Collectors.toList()); + + if (!staleDevices.isEmpty()) { + throw new StaleDevicesException(staleDevices); + } + } + + public static void validateCompleteDeviceList(Account account, Collection messages, Function getDeviceId, boolean isSyncMessage, + Optional authenticatedDeviceId) + throws MismatchedDevicesException { + Set messageDeviceIds = messages.stream().map(getDeviceId) + .collect(Collectors.toSet()); + validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId); + } + + public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage, + Optional authenticatedDeviceId) + throws MismatchedDevicesException { + Set accountDeviceIds = new HashSet<>(); + + List missingDeviceIds = new LinkedList<>(); + List extraDeviceIds = new LinkedList<>(); + + for (Device device : account.getDevices()) { + if (device.isEnabled() && + !(isSyncMessage && device.getId() == authenticatedDeviceId.get())) { + accountDeviceIds.add(device.getId()); + + if (!messageDeviceIds.contains(device.getId())) { + missingDeviceIds.add(device.getId()); + } + } + } + + for (Long deviceId : messageDeviceIds) { + if (!accountDeviceIds.contains(deviceId)) { + extraDeviceIds.add(deviceId); + } + } + + if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { + throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds); + } + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index a3441f88f..33444d950 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -85,6 +85,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.MessageValidation; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -775,196 +776,4 @@ class MessageControllerTest { Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false) ); } - - static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) { - Account account = mock(Account.class); - if (deviceAndRegistrationIds.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) { - if (!(deviceAndRegistrationIds[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { - throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); - } - Long deviceId = (Long) deviceAndRegistrationIds[i]; - Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; - Device device = mock(Device.class); - when(device.getRegistrationId()).thenReturn(registrationId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - } - return account; - } - - static Collection> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) { - final Collection> result = new HashSet<>(deviceAndRegistrationIds.length); - if (deviceAndRegistrationIds.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) { - if (!(deviceAndRegistrationIds[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { - throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); - } - Long deviceId = (Long) deviceAndRegistrationIds[i]; - Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; - result.add(new Pair<>(deviceId, registrationId)); - } - return result; - } - - static Stream validateRegistrationIdsSource() { - return Stream.of( - arguments( - mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), - deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 1492), - Set.of(1L)), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 42), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42), - deviceAndRegistrationIds(1L, 0), - null), - arguments( - mockAccountWithDeviceAndRegId(1L, 42, 2L, 255), - deviceAndRegistrationIds(1L, 0, 2L, 42), - Set.of(2L)), - arguments( - mockAccountWithDeviceAndRegId(1L, 42, 2L, 256), - deviceAndRegistrationIds(1L, 41, 2L, 257), - Set.of(1L, 2L)) - ); - } - - @ParameterizedTest - @MethodSource("validateRegistrationIdsSource") - void testValidateRegistrationIds( - Account account, - Collection> deviceAndRegistrationIds, - Set expectedStaleDeviceIds) throws Exception { - if (expectedStaleDeviceIds != null) { - Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> { - MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); - }).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds); - } else { - MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); - } - } - - static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) { - Account account = mock(Account.class); - if (deviceIdAndEnabled.length % 2 != 0) { - throw new IllegalArgumentException("invalid number of arguments specified; must be even"); - } - final Set devices = new HashSet<>(deviceIdAndEnabled.length / 2); - for (int i = 0; i < deviceIdAndEnabled.length; i+=2) { - if (!(deviceIdAndEnabled[i] instanceof Long)) { - throw new IllegalArgumentException("device id is not instance of long at index " + i); - } - if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) { - throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1)); - } - Long deviceId = (Long) deviceIdAndEnabled[i]; - Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1]; - Device device = mock(Device.class); - when(device.isEnabled()).thenReturn(enabled); - when(device.getId()).thenReturn(deviceId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - devices.add(device); - } - when(account.getDevices()).thenReturn(devices); - return account; - } - - static Stream validateCompleteDeviceListSource() { - return Stream.of( - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 3L), - null, - null, - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 2L, 3L), - null, - Set.of(2L), - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L), - Set.of(3L), - null, - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L, 2L), - Set.of(3L), - Set.of(2L), - false, - null), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(1L), - Set.of(3L), - Set.of(1L), - true, - 1L - ), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(2L), - Set.of(3L), - Set.of(2L), - true, - 1L - ), - arguments( - mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), - Set.of(3L), - null, - null, - true, - 1L - ) - ); - } - - @ParameterizedTest - @MethodSource("validateCompleteDeviceListSource") - void testValidateCompleteDeviceList( - Account account, - Set deviceIds, - Collection expectedMissingDeviceIds, - Collection expectedExtraDeviceIds, - boolean isSyncMessage, - Long authenticatedDeviceId) throws Exception { - if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { - final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, - () -> MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage, - Optional.ofNullable(authenticatedDeviceId))); - if (expectedMissingDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) - .hasSameElementsAs(expectedMissingDeviceIds); - } - if (expectedExtraDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); - } - } else { - MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage, - Optional.ofNullable(authenticatedDeviceId)); - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java new file mode 100644 index 000000000..7c8a25c34 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +import org.apache.commons.codec.binary.Base64; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.push.MessageSender; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ChangeNumberManagerTest { + private static AccountsManager accountsManager = mock(AccountsManager.class); + private static MessageSender messageSender = mock(MessageSender.class); + private ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); + + @BeforeEach + void reset() throws Exception { + Mockito.reset(accountsManager, messageSender); + when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer) invocation -> { + final Account account = invocation.getArgument(0, Account.class); + final String number = invocation.getArgument(1, String.class); + + final UUID uuid = account.getUuid(); + final Set devices = account.getDevices(); + + final Account updatedAccount = mock(Account.class); + when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getNumber()).thenReturn(number); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + when(updatedAccount.getDevices()).thenReturn(devices); + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } + + return updatedAccount; + }); + } + + @Test + void changeNumberNoMessages() throws Exception { + Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + changeNumberManager.changeNumber(account, "+18025551234", Collections.EMPTY_MAP, Collections.EMPTY_LIST); + verify(accountsManager).changeNumber(account, "+18025551234"); + verify(accountsManager, never()).updateDevice(any(), eq(1L), any()); + verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); + } + + @Test + void changeNumberSetPrimaryDevicePrekey() throws Exception { + Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + var prekeys = Map.of(1L, new SignedPreKey()); + changeNumberManager.changeNumber(account, "+18025551234", prekeys, Collections.EMPTY_LIST); + verify(accountsManager).changeNumber(account, "+18025551234"); + verify(accountsManager).updateDevice(any(), eq(1L), any()); + verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); + } + + @Test + void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { + Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + Device d2 = mock(Device.class); + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + var prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + IncomingMessage msg = mock(IncomingMessage.class); + when(msg.getDestinationDeviceId()).thenReturn(2L); + when(msg.getContent()).thenReturn(Base64.encodeBase64String(new byte[]{1})); + changeNumberManager.changeNumber(account, "+18025551234", prekeys, List.of(msg)); + verify(accountsManager).changeNumber(account, "+18025551234"); + verify(accountsManager).updateDevice(any(), eq(1L), any()); + verify(accountsManager).updateDevice(any(), eq(2L), any()); + verify(messageSender).sendMessage(any(), eq(d2), any(), eq(false)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index c0cf27b13..ae8b078cd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -33,6 +33,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -69,8 +70,10 @@ import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; @@ -88,6 +91,8 @@ import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException; @@ -141,6 +146,7 @@ class AccountControllerTest { private static RecaptchaClient recaptchaClient = mock(RecaptchaClient.class); private static GCMSender gcmSender = mock(GCMSender.class); private static APNSender apnSender = mock(APNSender.class); + private static ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class); private static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @@ -172,8 +178,9 @@ class AccountControllerTest { recaptchaClient, gcmSender, apnSender, - storageCredentialGenerator, - verifyExperimentEnrollmentManager)) + verifyExperimentEnrollmentManager, + changeNumberManager, + storageCredentialGenerator)) .build(); @@ -243,16 +250,22 @@ class AccountControllerTest { when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername")) .thenThrow(new UsernameNotAvailableException()); - when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer) invocation -> { + when(changeNumberManager.changeNumber(any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); final UUID uuid = account.getUuid(); + final Set devices = account.getDevices(); final Account updatedAccount = mock(Account.class); when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getNumber()).thenReturn(number); when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + when(updatedAccount.getDevices()).thenReturn(devices); + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } return updatedAccount; }); @@ -305,7 +318,8 @@ class AccountControllerTest { recaptchaClient, gcmSender, apnSender, - verifyExperimentEnrollmentManager); + verifyExperimentEnrollmentManager, + changeNumberManager); clearInvocations(AuthHelper.DISABLED_DEVICE); } @@ -1221,7 +1235,7 @@ class AccountControllerTest { } @Test - void testChangePhoneNumber() throws InterruptedException { + void testChangePhoneNumber() throws Exception { final String number = "+18005559876"; final String code = "987654"; @@ -1233,10 +1247,10 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(accountsManager).changeNumber(AuthHelper.VALID_ACCOUNT, number); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any()); assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.getNumber()).isEqualTo(number); @@ -1244,7 +1258,7 @@ class AccountControllerTest { } @Test - void testChangePhoneNumberImpossibleNumber() throws InterruptedException { + void testChangePhoneNumberImpossibleNumber() throws Exception { final String number = "This is not a real phone number"; final String code = "987654"; @@ -1253,16 +1267,16 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.readEntity(String.class)).isBlank(); - verify(accountsManager, never()).changeNumber(any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberNonNormalized() throws InterruptedException { + void testChangePhoneNumberNonNormalized() throws Exception { final String number = "+4407700900111"; final String code = "987654"; @@ -1271,7 +1285,7 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); @@ -1280,28 +1294,24 @@ class AccountControllerTest { assertThat(responseEntity.getOriginalNumber()).isEqualTo(number); assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111"); - verify(accountsManager, never()).changeNumber(any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberSameNumber() throws InterruptedException { + void testChangePhoneNumberSameNumber() throws Exception { final AccountIdentityResponse accountIdentityResponse = resources.getJerseyTest() .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null), + .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); - - assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID); - assertThat(accountIdentityResponse.getNumber()).isEqualTo(AuthHelper.VALID_NUMBER); - assertThat(accountIdentityResponse.getPni()).isEqualTo(AuthHelper.VALID_PNI); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); } @Test - void testChangePhoneNumberNoPendingCode() throws InterruptedException { + void testChangePhoneNumberNoPendingCode() throws Exception { final String number = "+18005559876"; final String code = "987654"; @@ -1312,15 +1322,15 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberIncorrectCode() throws InterruptedException { + void testChangePhoneNumberIncorrectCode() throws Exception { final String number = "+18005559876"; final String code = "987654"; @@ -1332,15 +1342,15 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberExistingAccountReglockNotRequired() throws InterruptedException { + void testChangePhoneNumberExistingAccountReglockNotRequired() throws Exception { final String number = "+18005559876"; final String code = "987654"; @@ -1362,15 +1372,15 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); - verify(accountsManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); } @Test - void testChangePhoneNumberExistingAccountReglockRequiredNotProvided() throws InterruptedException { + void testChangePhoneNumberExistingAccountReglockRequiredNotProvided() throws Exception { final String number = "+18005559876"; final String code = "987654"; @@ -1392,15 +1402,15 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberExistingAccountReglockRequiredIncorrect() throws InterruptedException { + void testChangePhoneNumberExistingAccountReglockRequiredIncorrect() throws Exception { final String number = "+18005559876"; final String code = "987654"; final String reglock = "setec-astronomy"; @@ -1424,15 +1434,15 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); - verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); } @Test - void testChangePhoneNumberExistingAccountReglockRequiredCorrect() throws InterruptedException { + void testChangePhoneNumberExistingAccountReglockRequiredCorrect() throws Exception { final String number = "+18005559876"; final String code = "987654"; final String reglock = "setec-astronomy"; @@ -1456,11 +1466,142 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); - verify(accountsManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any()); + } + + @Test + void testChangePhoneNumberDeviceMessagesWithoutPrekeys() throws Exception { + final String number = "+18005559876"; + final String code = "987654"; + + when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( + new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); + + final Response response = + resources.getJerseyTest() + .target("/v1/accounts/number") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, + List.of(new IncomingMessage(1, null, 1, 1, "foo")), null), + MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(400); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + } + + @Test + void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchDeviceIDs() throws Exception { + final String number = "+18005559876"; + final String code = "987654"; + + Device device2 = mock(Device.class); + when(device2.getId()).thenReturn(2L); + when(device2.isEnabled()).thenReturn(true); + Device device3 = mock(Device.class); + when(device3.getId()).thenReturn(3L); + when(device3.isEnabled()).thenReturn(true); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); + when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( + new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); + + final Response response = + resources.getJerseyTest() + .target("/v1/accounts/number") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ChangePhoneNumberRequest( + number, code, null, + List.of( + new IncomingMessage(1, null, 2, 1, "foo"), + new IncomingMessage(1, null, 4, 1, "foo")), + Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())), + MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(409); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any()); + } + + @Test + void testChangePhoneNumberChangePrekeys() throws Exception { + final String number = "+18005559876"; + final String code = "987654"; + + Device device2 = mock(Device.class); + when(device2.getId()).thenReturn(2L); + when(device2.isEnabled()).thenReturn(true); + when(device2.getRegistrationId()).thenReturn(2); + Device device3 = mock(Device.class); + when(device3.getId()).thenReturn(3L); + when(device3.isEnabled()).thenReturn(true); + when(device3.getRegistrationId()).thenReturn(3); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); + when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); + when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); + when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( + new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); + + var deviceMessages = List.of( + new IncomingMessage(1, null, 2, 2, "content2"), + new IncomingMessage(1, null, 3, 3, "content3")); + var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + + final AccountIdentityResponse accountIdentityResponse = + resources.getJerseyTest() + .target("/v1/accounts/number") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ChangePhoneNumberRequest( + number, code, null, + deviceMessages, + deviceKeys), + MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); + + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any()); + + assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID); + assertThat(accountIdentityResponse.getNumber()).isEqualTo(number); + assertThat(accountIdentityResponse.getPni()).isNotEqualTo(AuthHelper.VALID_PNI); + } + + @Test + void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchRegistrationID() throws Exception { + final String number = "+18005559876"; + final String code = "987654"; + + Device device2 = mock(Device.class); + when(device2.getId()).thenReturn(2L); + when(device2.isEnabled()).thenReturn(true); + when(device2.getRegistrationId()).thenReturn(2); + Device device3 = mock(Device.class); + when(device3.getId()).thenReturn(3L); + when(device3.isEnabled()).thenReturn(true); + when(device3.getRegistrationId()).thenReturn(3); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); + when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); + when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); + when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( + new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); + + final Response response = + resources.getJerseyTest() + .target("/v1/accounts/number") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ChangePhoneNumberRequest( + number, code, null, + List.of( + new IncomingMessage(1, null, 2, 1, "foo"), + new IncomingMessage(1, null, 3, 1, "foo")), + Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())), + MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(410); + verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java new file mode 100644 index 000000000..9235021a0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageValidationTest.java @@ -0,0 +1,225 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.util.Collection; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.MessageValidation; +import org.whispersystems.textsecuregcm.util.Pair; + +@ExtendWith(DropwizardExtensionsSupport.class) +class MessageValidationTest { + + static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) { + Account account = mock(Account.class); + if (deviceAndRegistrationIds.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) { + if (!(deviceAndRegistrationIds[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { + throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); + } + Long deviceId = (Long) deviceAndRegistrationIds[i]; + Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; + Device device = mock(Device.class); + when(device.getRegistrationId()).thenReturn(registrationId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + } + return account; + } + + static Collection> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) { + final Collection> result = new HashSet<>(deviceAndRegistrationIds.length); + if (deviceAndRegistrationIds.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) { + if (!(deviceAndRegistrationIds[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { + throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); + } + Long deviceId = (Long) deviceAndRegistrationIds[i]; + Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; + result.add(new Pair<>(deviceId, registrationId)); + } + return result; + } + + static Stream validateRegistrationIdsSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 1492), + Set.of(1L)), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 42), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 0), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42, 2L, 255), + deviceAndRegistrationIds(1L, 0, 2L, 42), + Set.of(2L)), + arguments( + mockAccountWithDeviceAndRegId(1L, 42, 2L, 256), + deviceAndRegistrationIds(1L, 41, 2L, 257), + Set.of(1L, 2L)) + ); + } + + @ParameterizedTest + @MethodSource("validateRegistrationIdsSource") + void testValidateRegistrationIds( + Account account, + Collection> deviceAndRegistrationIds, + Set expectedStaleDeviceIds) throws Exception { + if (expectedStaleDeviceIds != null) { + Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> { + MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); + }).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds); + } else { + MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); + } + } + + static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) { + Account account = mock(Account.class); + if (deviceIdAndEnabled.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + final Set devices = new HashSet<>(deviceIdAndEnabled.length / 2); + for (int i = 0; i < deviceIdAndEnabled.length; i+=2) { + if (!(deviceIdAndEnabled[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) { + throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1)); + } + Long deviceId = (Long) deviceIdAndEnabled[i]; + Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1]; + Device device = mock(Device.class); + when(device.isEnabled()).thenReturn(enabled); + when(device.getId()).thenReturn(deviceId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + devices.add(device); + } + when(account.getDevices()).thenReturn(devices); + return account; + } + + static Stream validateCompleteDeviceListSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 3L), + null, + null, + false, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 2L, 3L), + null, + Set.of(2L), + false, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L), + Set.of(3L), + null, + false, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 2L), + Set.of(3L), + Set.of(2L), + false, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L), + Set.of(3L), + Set.of(1L), + true, + 1L + ), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(2L), + Set.of(3L), + Set.of(2L), + true, + 1L + ), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(3L), + null, + null, + true, + 1L + ) + ); + } + + @ParameterizedTest + @MethodSource("validateCompleteDeviceListSource") + void testValidateCompleteDeviceList( + Account account, + Set deviceIds, + Collection expectedMissingDeviceIds, + Collection expectedExtraDeviceIds, + boolean isSyncMessage, + Long authenticatedDeviceId) throws Exception { + if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { + final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, + () -> MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage, + Optional.ofNullable(authenticatedDeviceId))); + if (expectedMissingDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) + .hasSameElementsAs(expectedMissingDeviceIds); + } + if (expectedExtraDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); + } + } else { + MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage, + Optional.ofNullable(authenticatedDeviceId)); + } + } +}