Allow primary to set and provide new signed prekeys for linked devices (#950)

This commit is contained in:
gram-signal 2022-04-15 12:39:47 -06:00 committed by GitHub
parent 7b3703506b
commit 473ecbdf2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 782 additions and 324 deletions

@ -1 +1 @@
Subproject commit 8f690df72ccda8fcbf048eee4c07f3e60d52f1fd
Subproject commit ae98ea5c61257e76c98dc4db9e5c2911facb5849

View File

@ -169,6 +169,7 @@ import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerListener;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.ContactDiscoveryWriter;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsDirectoryReconciler;
@ -496,6 +497,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new RateLimitChallengeOptionManager(dynamicRateLimiters, dynamicConfigurationManager);
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
final List<AccountDatabaseCrawlerListener> directoryReconciliationAccountDatabaseCrawlerListeners = new ArrayList<>();
final List<DeletedAccountsDirectoryReconciler> deletedAccountsDirectoryReconcilers = new ArrayList<>();
@ -622,8 +624,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(
new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters,
smsSender, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(),
recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator,
verifyExperimentEnrollmentManager));
recaptchaClient, gcmSender, apnSender, verifyExperimentEnrollmentManager,
changeNumberManager, backupCredentialsGenerator));
environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager));
final List<Object> commonControllers = Lists.newArrayList(

View File

@ -33,6 +33,7 @@ import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HEAD;
import javax.ws.rs.HeaderParam;
@ -69,8 +70,11 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.DeviceName;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.APNSender;
@ -84,6 +88,7 @@ import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
@ -92,6 +97,7 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
import org.whispersystems.textsecuregcm.util.Username;
import org.whispersystems.textsecuregcm.util.Util;
@ -138,6 +144,7 @@ public class AccountController {
private final ExternalServiceCredentialGenerator backupServiceCredentialGenerator;
private final TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager;
private final ChangeNumberManager changeNumberManager;
public AccountController(StoredVerificationCodeManager pendingAccounts,
AccountsManager accounts,
@ -150,8 +157,9 @@ public class AccountController {
RecaptchaClient recaptchaClient,
GCMSender gcmSender,
APNSender apnSender,
ExternalServiceCredentialGenerator backupServiceCredentialGenerator,
TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager)
TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager,
ChangeNumberManager changeNumberManager,
ExternalServiceCredentialGenerator backupServiceCredentialGenerator)
{
this.pendingAccounts = pendingAccounts;
this.accounts = accounts;
@ -164,8 +172,9 @@ public class AccountController {
this.recaptchaClient = recaptchaClient;
this.gcmSender = gcmSender;
this.apnSender = apnSender;
this.backupServiceCredentialGenerator = backupServiceCredentialGenerator;
this.verifyExperimentEnrollmentManager = verifyExperimentEnrollmentManager;
this.backupServiceCredentialGenerator = backupServiceCredentialGenerator;
this.changeNumberManager = changeNumberManager;
}
@Timed
@ -403,38 +412,75 @@ public class AccountController {
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount, @NotNull @Valid final ChangePhoneNumberRequest request)
throws RateLimitExceededException, InterruptedException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
final Account updatedAccount;
if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) {
throw new ForbiddenException();
}
if (request.getNumber().equals(authenticatedAccount.getAccount().getNumber())) {
// This may be a request that got repeated due to poor network conditions or other client error; take no action,
// but report success since the account is in the desired state
updatedAccount = authenticatedAccount.getAccount();
} else {
Util.requireNormalizedNumber(request.getNumber());
if (request.getDeviceSignedPrekeys() != null && !request.getDeviceSignedPrekeys().isEmpty()) {
if (request.getDeviceMessages() == null || request.getDeviceMessages().size() != request.getDeviceSignedPrekeys().size() - 1) {
// device_messages should exist and be one shorter than device_signed_prekeys, since it doesn't have the primary's key.
throw new WebApplicationException(Response.status(400).build());
}
try {
// Checks that all except master ID are in device messages
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, true, Optional.of(Device.MASTER_ID));
MessageValidation.validateRegistrationIds(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
// Checks that all including master ID are in signed prekeys
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceSignedPrekeys().entrySet(),
e -> e.getKey(), false, Optional.empty());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
} else if (request.getDeviceMessages() != null && !request.getDeviceMessages().isEmpty()) {
// device_messages shouldn't exist without device_signed_prekeys.
throw new WebApplicationException(Response.status(400).build());
}
rateLimiters.getVerifyLimiter().validate(request.getNumber());
final String number = request.getNumber();
if (!authenticatedAccount.getAccount().getNumber().equals(number)) {
Util.requireNormalizedNumber(number);
rateLimiters.getVerifyLimiter().validate(number);
final Optional<StoredVerificationCode> storedVerificationCode =
pendingAccounts.getCodeForNumber(request.getNumber());
pendingAccounts.getCodeForNumber(number);
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.getCode())) {
throw new WebApplicationException(Response.status(403).build());
throw new ForbiddenException();
}
storedVerificationCode.flatMap(StoredVerificationCode::getTwilioVerificationSid)
.ifPresent(smsSender::reportVerificationSucceeded);
final Optional<Account> existingAccount = accounts.getByE164(request.getNumber());
final Optional<Account> existingAccount = accounts.getByE164(number);
if (existingAccount.isPresent()) {
verifyRegistrationLock(existingAccount.get(), request.getRegistrationLock());
}
rateLimiters.getVerifyLimiter().clear(request.getNumber());
updatedAccount = accounts.changeNumber(authenticatedAccount.getAccount(), request.getNumber());
rateLimiters.getVerifyLimiter().clear(number);
}
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.getNumber(),
Optional.ofNullable(request.getDeviceSignedPrekeys()).orElse(Collections.emptyMap()),
Optional.ofNullable(request.getDeviceMessages()).orElse(Collections.emptyList()));
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),

View File

@ -92,6 +92,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ -225,10 +226,10 @@ public class MessageController {
checkRateLimit(source.get(), destination.get(), userAgent);
}
validateCompleteDeviceList(destination.get(), messages.getMessages(),
MessageValidation.validateCompleteDeviceList(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
validateRegistrationIds(destination.get(), messages.getMessages(),
MessageValidation.validateRegistrationIds(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
@ -319,10 +320,10 @@ public class MessageController {
}
final List<IncomingDeviceMessage> messagesAsList = Arrays.asList(messages);
validateCompleteDeviceList(destination.get(), messagesAsList,
MessageValidation.validateCompleteDeviceList(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
validateRegistrationIds(destination.get(), messagesAsList,
MessageValidation.validateRegistrationIds(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId);
@ -402,8 +403,8 @@ public class MessageController {
final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account);
final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet());
try {
validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
MessageValidation.validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
MessageValidation.validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
@ -731,72 +732,6 @@ public class MessageController {
}
}
@VisibleForTesting
public static <T> void validateRegistrationIds(Account account, List<T> messages, Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId)
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message)));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
@VisibleForTesting
public static void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
@VisibleForTesting
public static <T> void validateCompleteDeviceList(Account account, List<T> messages, Function<T, Long> getDeviceId, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(getDeviceId)
.collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
}
@VisibleForTesting
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) {
if (device.isEnabled() &&
!(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}
for (Long deviceId : messageDeviceIds) {
if (!accountDeviceIds.contains(deviceId)) {
extraDeviceIds.add(deviceId);
}
}
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds);
}
}
private void validateContentLength(final int contentLength, final String userAgent) {
Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(contentLength);
@ -818,7 +753,7 @@ public class MessageController {
}
}
private Optional<byte[]> getMessageContent(IncomingMessage message) {
public static Optional<byte[]> getMessageContent(IncomingMessage message) {
if (Util.isEmpty(message.getContent())) return Optional.empty();
try {

View File

@ -9,6 +9,8 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nullable;
import javax.validation.constraints.NotBlank;
import java.util.List;
import java.util.Map;
public class ChangePhoneNumberRequest {
@ -24,14 +26,26 @@ public class ChangePhoneNumberRequest {
@Nullable
final String registrationLock;
@JsonProperty("device_messages")
@Nullable
final List<IncomingMessage> deviceMessages;
@JsonProperty("device_signed_prekeys")
@Nullable
final Map<Long, SignedPreKey> deviceSignedPrekeys;
@JsonCreator
public ChangePhoneNumberRequest(@JsonProperty("number") final String number,
@JsonProperty("code") final String code,
@JsonProperty("reglock") @Nullable final String registrationLock) {
@JsonProperty("reglock") @Nullable final String registrationLock,
@JsonProperty("device_messages") @Nullable final List<IncomingMessage> deviceMessages,
@JsonProperty("device_signed_prekeys") @Nullable final Map<Long, SignedPreKey> deviceSignedPrekeys) {
this.number = number;
this.code = code;
this.registrationLock = registrationLock;
this.deviceMessages = deviceMessages;
this.deviceSignedPrekeys = deviceSignedPrekeys;
}
public String getNumber() {
@ -46,4 +60,14 @@ public class ChangePhoneNumberRequest {
public String getRegistrationLock() {
return registrationLock;
}
@Nullable
public List<IncomingMessage> getDeviceMessages() {
return deviceMessages;
}
@Nullable
public Map<Long, SignedPreKey> getDeviceSignedPrekeys() {
return deviceSignedPrekeys;
}
}

View File

@ -0,0 +1,95 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import javax.validation.constraints.NotNull;
import java.util.List;
import java.util.Map;
import java.util.Optional;
public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
private final MessageSender messageSender;
private final AccountsManager accountsManager;
public ChangeNumberManager(
final MessageSender messageSender,
final AccountsManager accountsManager) {
this.messageSender = messageSender;
this.accountsManager = accountsManager;
}
public Account changeNumber(
@NotNull Account account,
@NotNull final String number,
@NotNull final Map<Long, SignedPreKey> deviceSignedPrekeys,
@NotNull final List<IncomingMessage> deviceMessages) throws InterruptedException {
final Account updatedAccount;
if (number.equals(account.getNumber())) {
// This may be a request that got repeated due to poor network conditions or other client error; take no action,
// but report success since the account is in the desired state
updatedAccount = account;
} else {
updatedAccount = accountsManager.changeNumber(account, number);
}
// Whether the account already has this number or not, we reset signed prekeys and resend messages.
// This makes it so the client can resend a request they didn't get a response for (timeout, etc)
// to make sure their messages sent and prekeys were updated, even if the first time around the
// server crashed at/above this point.
if (deviceSignedPrekeys != null && !deviceSignedPrekeys.isEmpty()) {
for (Map.Entry<Long, SignedPreKey> entry : deviceSignedPrekeys.entrySet()) {
accountsManager.updateDevice(updatedAccount, entry.getKey(),
d -> d.setPhoneNumberIdentitySignedPreKey(entry.getValue()));
}
for (IncomingMessage message : deviceMessages) {
sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message);
}
}
return updatedAccount;
}
@VisibleForTesting
void sendMessageToSelf(
Account sourceAndDestinationAccount, Optional<Device> destinationDevice, IncomingMessage message) {
Optional<byte[]> contents = MessageController.getMessageContent(message);
if (!contents.isPresent()) {
logger.debug("empty message contents sending to self, ignoring");
return;
} else if (!destinationDevice.isPresent()) {
logger.debug("destination device not present");
return;
}
try {
long serverTimestamp = System.currentTimeMillis();
Envelope envelope = Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.getType()))
.setTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(sourceAndDestinationAccount.getUuid().toString())
.setContent(ByteString.copyFrom(contents.get()))
.setSource(sourceAndDestinationAccount.getNumber())
.setSourceUuid(sourceAndDestinationAccount.getUuid().toString())
.setSourceDevice((int) Device.MASTER_ID)
.build();
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
} catch (NotPushRegisteredException e) {
logger.debug("Not registered", e);
}
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MessageValidation {
public static <T> void validateRegistrationIds(Account account, List<T> messages, Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId)
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message)));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
public static void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
public static <T> void validateCompleteDeviceList(Account account, Collection<T> messages, Function<T, Long> getDeviceId, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(getDeviceId)
.collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
}
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) {
if (device.isEnabled() &&
!(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}
for (Long deviceId : messageDeviceIds) {
if (!accountDeviceIds.contains(deviceId)) {
extraDeviceIds.add(deviceId);
}
}
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds);
}
}
}

View File

@ -85,6 +85,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -775,196 +776,4 @@ class MessageControllerTest {
Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false)
);
}
static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) {
Account account = mock(Account.class);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
}
return account;
}
static Collection<Pair<Long, Integer>> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) {
final Collection<Pair<Long, Integer>> result = new HashSet<>(deviceAndRegistrationIds.length);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
result.add(new Pair<>(deviceId, registrationId));
}
return result;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 255),
deviceAndRegistrationIds(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 256),
deviceAndRegistrationIds(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Collection<Pair<Long, Integer>> deviceAndRegistrationIds,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> {
MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds);
} else {
MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}
}
static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) {
Account account = mock(Account.class);
if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
final Set<Device> devices = new HashSet<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) {
throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1));
}
Long deviceId = (Long) deviceIdAndEnabled[i];
Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1];
Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
}
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
}
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.apache.commons.codec.binary.Base64;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ChangeNumberManagerTest {
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessageSender messageSender = mock(MessageSender.class);
private ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
@BeforeEach
void reset() throws Exception {
Mockito.reset(accountsManager, messageSender);
when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getUuid();
final Set<Device> devices = account.getDevices();
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
return updatedAccount;
});
}
@Test
void changeNumberNoMessages() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", Collections.EMPTY_MAP, Collections.EMPTY_LIST);
verify(accountsManager).changeNumber(account, "+18025551234");
verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@Test
void changeNumberSetPrimaryDevicePrekey() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
var prekeys = Map.of(1L, new SignedPreKey());
changeNumberManager.changeNumber(account, "+18025551234", prekeys, Collections.EMPTY_LIST);
verify(accountsManager).changeNumber(account, "+18025551234");
verify(accountsManager).updateDevice(any(), eq(1L), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@Test
void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
Device d2 = mock(Device.class);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
var prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
IncomingMessage msg = mock(IncomingMessage.class);
when(msg.getDestinationDeviceId()).thenReturn(2L);
when(msg.getContent()).thenReturn(Base64.encodeBase64String(new byte[]{1}));
changeNumberManager.changeNumber(account, "+18025551234", prekeys, List.of(msg));
verify(accountsManager).changeNumber(account, "+18025551234");
verify(accountsManager).updateDevice(any(), eq(1L), any());
verify(accountsManager).updateDevice(any(), eq(2L), any());
verify(messageSender).sendMessage(any(), eq(d2), any(), eq(false));
}
}

View File

@ -33,6 +33,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
@ -69,8 +70,10 @@ import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
@ -88,6 +91,8 @@ import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
@ -141,6 +146,7 @@ class AccountControllerTest {
private static RecaptchaClient recaptchaClient = mock(RecaptchaClient.class);
private static GCMSender gcmSender = mock(GCMSender.class);
private static APNSender apnSender = mock(APNSender.class);
private static ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
private static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -172,8 +178,9 @@ class AccountControllerTest {
recaptchaClient,
gcmSender,
apnSender,
storageCredentialGenerator,
verifyExperimentEnrollmentManager))
verifyExperimentEnrollmentManager,
changeNumberManager,
storageCredentialGenerator))
.build();
@ -243,16 +250,22 @@ class AccountControllerTest {
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername"))
.thenThrow(new UsernameNotAvailableException());
when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer<Account>) invocation -> {
when(changeNumberManager.changeNumber(any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getUuid();
final Set<Device> devices = account.getDevices();
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
return updatedAccount;
});
@ -305,7 +318,8 @@ class AccountControllerTest {
recaptchaClient,
gcmSender,
apnSender,
verifyExperimentEnrollmentManager);
verifyExperimentEnrollmentManager,
changeNumberManager);
clearInvocations(AuthHelper.DISABLED_DEVICE);
}
@ -1221,7 +1235,7 @@ class AccountControllerTest {
}
@Test
void testChangePhoneNumber() throws InterruptedException {
void testChangePhoneNumber() throws Exception {
final String number = "+18005559876";
final String code = "987654";
@ -1233,10 +1247,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(accountsManager).changeNumber(AuthHelper.VALID_ACCOUNT, number);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
@ -1244,7 +1258,7 @@ class AccountControllerTest {
}
@Test
void testChangePhoneNumberImpossibleNumber() throws InterruptedException {
void testChangePhoneNumberImpossibleNumber() throws Exception {
final String number = "This is not a real phone number";
final String code = "987654";
@ -1253,16 +1267,16 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.readEntity(String.class)).isBlank();
verify(accountsManager, never()).changeNumber(any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberNonNormalized() throws InterruptedException {
void testChangePhoneNumberNonNormalized() throws Exception {
final String number = "+4407700900111";
final String code = "987654";
@ -1271,7 +1285,7 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
@ -1280,28 +1294,24 @@ class AccountControllerTest {
assertThat(responseEntity.getOriginalNumber()).isEqualTo(number);
assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111");
verify(accountsManager, never()).changeNumber(any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberSameNumber() throws InterruptedException {
void testChangePhoneNumberSameNumber() throws Exception {
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null),
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(AuthHelper.VALID_NUMBER);
assertThat(accountIdentityResponse.getPni()).isEqualTo(AuthHelper.VALID_PNI);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
}
@Test
void testChangePhoneNumberNoPendingCode() throws InterruptedException {
void testChangePhoneNumberNoPendingCode() throws Exception {
final String number = "+18005559876";
final String code = "987654";
@ -1312,15 +1322,15 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberIncorrectCode() throws InterruptedException {
void testChangePhoneNumberIncorrectCode() throws Exception {
final String number = "+18005559876";
final String code = "987654";
@ -1332,15 +1342,15 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberExistingAccountReglockNotRequired() throws InterruptedException {
void testChangePhoneNumberExistingAccountReglockNotRequired() throws Exception {
final String number = "+18005559876";
final String code = "987654";
@ -1362,15 +1372,15 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(accountsManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
}
@Test
void testChangePhoneNumberExistingAccountReglockRequiredNotProvided() throws InterruptedException {
void testChangePhoneNumberExistingAccountReglockRequiredNotProvided() throws Exception {
final String number = "+18005559876";
final String code = "987654";
@ -1392,15 +1402,15 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberExistingAccountReglockRequiredIncorrect() throws InterruptedException {
void testChangePhoneNumberExistingAccountReglockRequiredIncorrect() throws Exception {
final String number = "+18005559876";
final String code = "987654";
final String reglock = "setec-astronomy";
@ -1424,15 +1434,15 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberExistingAccountReglockRequiredCorrect() throws InterruptedException {
void testChangePhoneNumberExistingAccountReglockRequiredCorrect() throws Exception {
final String number = "+18005559876";
final String code = "987654";
final String reglock = "setec-astronomy";
@ -1456,11 +1466,142 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(accountsManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
}
@Test
void testChangePhoneNumberDeviceMessagesWithoutPrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null,
List.of(new IncomingMessage(1, null, 1, 1, "foo")), null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchDeviceIDs() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 4, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
var deviceMessages = List.of(
new IncomingMessage(1, null, 2, 2, "content2"),
new IncomingMessage(1, null, 3, 3, "content3"));
var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
deviceMessages,
deviceKeys),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
assertThat(accountIdentityResponse.getPni()).isNotEqualTo(AuthHelper.VALID_PNI);
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchRegistrationID() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 3, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(410);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test

View File

@ -0,0 +1,225 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageValidationTest {
static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) {
Account account = mock(Account.class);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
}
return account;
}
static Collection<Pair<Long, Integer>> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) {
final Collection<Pair<Long, Integer>> result = new HashSet<>(deviceAndRegistrationIds.length);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
result.add(new Pair<>(deviceId, registrationId));
}
return result;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 255),
deviceAndRegistrationIds(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 256),
deviceAndRegistrationIds(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Collection<Pair<Long, Integer>> deviceAndRegistrationIds,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds);
} else {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}
}
static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) {
Account account = mock(Account.class);
if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
final Set<Device> devices = new HashSet<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) {
throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1));
}
Long deviceId = (Long) deviceIdAndEnabled[i];
Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1];
Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
}
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
}
}
}