Add support for setting PNI-associated registration IDs and identity keys when changing numbers

This commit is contained in:
Jon Chambers 2022-07-26 15:19:27 -04:00 committed by GitHub
parent c252118cfc
commit dce391a248
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 927 additions and 673 deletions

View File

@ -68,13 +68,11 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.DeviceName;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
@ -95,7 +93,6 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
import org.whispersystems.textsecuregcm.util.Username;
import org.whispersystems.textsecuregcm.util.Util;
@ -416,41 +413,9 @@ public class AccountController {
throw new ForbiddenException();
}
if (request.getDeviceSignedPrekeys() != null && !request.getDeviceSignedPrekeys().isEmpty()) {
if (request.getDeviceMessages() == null || request.getDeviceMessages().size() != request.getDeviceSignedPrekeys().size() - 1) {
// device_messages should exist and be one shorter than device_signed_prekeys, since it doesn't have the primary's key.
throw new WebApplicationException(Response.status(400).build());
}
try {
// Checks that all except master ID are in device messages
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, true, Optional.of(Device.MASTER_ID));
MessageValidation.validateRegistrationIds(
authenticatedAccount.getAccount(), request.getDeviceMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
// Checks that all including master ID are in signed prekeys
MessageValidation.validateCompleteDeviceList(
authenticatedAccount.getAccount(), request.getDeviceSignedPrekeys().entrySet(),
e -> e.getKey(), false, Optional.empty());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
} else if (request.getDeviceMessages() != null && !request.getDeviceMessages().isEmpty()) {
// device_messages shouldn't exist without device_signed_prekeys.
throw new WebApplicationException(Response.status(400).build());
}
final String number = request.number();
final String number = request.getNumber();
// Only "bill" for rate limiting if we think there's a change to be made...
if (!authenticatedAccount.getAccount().getNumber().equals(number)) {
Util.requireNormalizedNumber(number);
@ -459,7 +424,7 @@ public class AccountController {
final Optional<StoredVerificationCode> storedVerificationCode =
pendingAccounts.getCodeForNumber(number);
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.getCode())) {
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(request.code())) {
throw new ForbiddenException();
}
@ -469,24 +434,42 @@ public class AccountController {
final Optional<Account> existingAccount = accounts.getByE164(number);
if (existingAccount.isPresent()) {
verifyRegistrationLock(existingAccount.get(), request.getRegistrationLock());
verifyRegistrationLock(existingAccount.get(), request.registrationLock());
}
rateLimiters.getVerifyLimiter().clear(number);
}
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.getNumber(),
Optional.ofNullable(request.getDeviceSignedPrekeys()).orElse(Collections.emptyMap()),
Optional.ofNullable(request.getDeviceMessages()).orElse(Collections.emptyList()));
// ...but always attempt to make the change in case a client retries and needs to re-send messages
try {
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.number(),
request.pniIdentityKey(),
Optional.ofNullable(request.devicePniSignedPrekeys()).orElse(Collections.emptyMap()),
Optional.ofNullable(request.deviceMessages()).orElse(Collections.emptyList()),
Optional.ofNullable(request.pniRegistrationIds()).orElse(Collections.emptyMap()));
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),
updatedAccount.getPhoneNumberIdentifier(),
updatedAccount.getUsername().orElse(null),
updatedAccount.isStorageSupported());
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),
updatedAccount.getPhoneNumberIdentifier(),
updatedAccount.getUsername().orElse(null),
updatedAccount.isStorageSupported());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
}
}
@Timed
@ -625,6 +608,7 @@ public class AccountController {
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
attributes.getPhoneNumberIdentityRegistrationId().ifPresent(d::setPhoneNumberIdentityRegistrationId);
d.setUserAgent(userAgent);
});

View File

@ -198,6 +198,7 @@ public class DeviceController {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId);
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());

View File

@ -197,7 +197,11 @@ public class KeysController {
PreKey preKey = preKeysByDeviceId.get(device.getId());
if (signedPreKey != null || preKey != null) {
responseItems.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey));
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey));
}
}
}

View File

@ -21,7 +21,7 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -89,8 +89,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@ -214,11 +213,23 @@ public class MessageController {
checkRateLimit(source.get(), destination.get(), userAgent);
}
MessageValidation.validateCompleteDeviceList(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
MessageValidation.validateRegistrationIds(destination.get(), messages.getMessages(),
IncomingMessage::getDestinationDeviceId, IncomingMessage::getDestinationRegistrationId);
final Set<Long> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
DestinationDeviceValidator.validateCompleteDeviceList(destination.get(),
messages.getMessages().stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination.get(),
messages.getMessages().stream().collect(Collectors.toMap(
IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationRegistrationId)),
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
@ -307,13 +318,25 @@ public class MessageController {
checkRateLimit(source.get(), destination.get(), userAgent);
}
final List<IncomingDeviceMessage> messagesAsList = Arrays.asList(messages);
MessageValidation.validateCompleteDeviceList(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId, isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
MessageValidation.validateRegistrationIds(destination.get(), messagesAsList,
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId);
final Set<Long> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
DestinationDeviceValidator.validateCompleteDeviceList(
destination.get(),
Arrays.stream(messages).map(IncomingDeviceMessage::getDeviceId).collect(Collectors.toSet()),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(
destination.get(),
Arrays.stream(messages).collect(Collectors.toMap(
IncomingDeviceMessage::getDeviceId,
IncomingDeviceMessage::getRegistrationId)),
destination.get().getPhoneNumberIdentifier().equals(destinationUuid));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
@ -372,27 +395,29 @@ public class MessageController {
}));
checkAccessKeys(accessKeys, uuidToAccountMap);
final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
Arrays
.stream(multiRecipientMessage.getRecipients())
.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
final Map<Account, Map<Long, Integer>> accountToDeviceIdAndRegistrationIdMap = Arrays.stream(multiRecipientMessage.getRecipients())
.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> Map.of(recipient.getDeviceId(), recipient.getRegistrationId()),
(a, b) -> {
final Map<Long, Integer> combined = new HashMap<>();
combined.putAll(a);
combined.putAll(b);
return combined;
}
));
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
uuidToAccountMap.values().forEach(account -> {
final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account);
final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet());
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).keySet();
try {
MessageValidation.validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
MessageValidation.validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
// Multi-recipient messages are always sealed-sender messages, and so can never be sent to a phone number
// identity
DestinationDeviceValidator.validateRegistrationIds(account, accountToDeviceIdAndRegistrationIdMap.get(account), false);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));

View File

@ -6,9 +6,11 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import javax.annotation.Nullable;
import javax.validation.constraints.Size;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.util.ExactlySize;
import java.util.OptionalInt;
public class AccountAttributes {
@ -18,6 +20,10 @@ public class AccountAttributes {
@JsonProperty
private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
@Size(max = 204, message = "This field must be less than 50 characters")
private String name;
@ -59,6 +65,10 @@ public class AccountAttributes {
return registrationId;
}
public OptionalInt getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
}
public String getName() {
return name;
}

View File

@ -5,69 +5,17 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nullable;
import javax.validation.constraints.NotBlank;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import javax.validation.constraints.NotBlank;
public class ChangePhoneNumberRequest {
@JsonProperty
@NotBlank
final String number;
@JsonProperty
@NotBlank
final String code;
@JsonProperty("reglock")
@Nullable
final String registrationLock;
@JsonProperty("device_messages")
@Nullable
final List<IncomingMessage> deviceMessages;
@JsonProperty("device_signed_prekeys")
@Nullable
final Map<Long, SignedPreKey> deviceSignedPrekeys;
@JsonCreator
public ChangePhoneNumberRequest(@JsonProperty("number") final String number,
@JsonProperty("code") final String code,
@JsonProperty("reglock") @Nullable final String registrationLock,
@JsonProperty("device_messages") @Nullable final List<IncomingMessage> deviceMessages,
@JsonProperty("device_signed_prekeys") @Nullable final Map<Long, SignedPreKey> deviceSignedPrekeys) {
this.number = number;
this.code = code;
this.registrationLock = registrationLock;
this.deviceMessages = deviceMessages;
this.deviceSignedPrekeys = deviceSignedPrekeys;
}
public String getNumber() {
return number;
}
public String getCode() {
return code;
}
@Nullable
public String getRegistrationLock() {
return registrationLock;
}
@Nullable
public List<IncomingMessage> getDeviceMessages() {
return deviceMessages;
}
@Nullable
public Map<Long, SignedPreKey> getDeviceSignedPrekeys() {
return deviceSignedPrekeys;
}
public record ChangePhoneNumberRequest(@NotBlank String number,
@NotBlank String code,
@JsonProperty("reglock") @Nullable String registrationLock,
@Nullable String pniIdentityKey,
@Nullable List<IncomingMessage> deviceMessages,
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Nullable Map<Long, Integer> pniRegistrationIds) {
}

View File

@ -20,6 +20,7 @@ public class OutgoingMessageEntity {
private final UUID sourceUuid;
private final int sourceDevice;
private final UUID destinationUuid;
private final UUID updatedPni;
private final byte[] content;
private final long serverTimestamp;
@ -31,6 +32,7 @@ public class OutgoingMessageEntity {
@JsonProperty("sourceUuid") final UUID sourceUuid,
@JsonProperty("sourceDevice") final int sourceDevice,
@JsonProperty("destinationUuid") final UUID destinationUuid,
@JsonProperty("updatedPni") final UUID updatedPni,
@JsonProperty("content") final byte[] content,
@JsonProperty("serverTimestamp") final long serverTimestamp)
{
@ -41,6 +43,7 @@ public class OutgoingMessageEntity {
this.sourceUuid = sourceUuid;
this.sourceDevice = sourceDevice;
this.destinationUuid = destinationUuid;
this.updatedPni = updatedPni;
this.content = content;
this.serverTimestamp = serverTimestamp;
}
@ -73,6 +76,10 @@ public class OutgoingMessageEntity {
return destinationUuid;
}
public UUID getUpdatedPni() {
return updatedPni;
}
public byte[] getContent() {
return content;
}
@ -83,23 +90,21 @@ public class OutgoingMessageEntity {
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity)o;
return type == that.type &&
timestamp == that.timestamp &&
sourceDevice == that.sourceDevice &&
serverTimestamp == that.serverTimestamp &&
guid.equals(that.guid) &&
Objects.equals(source, that.source) &&
Objects.equals(sourceUuid, that.sourceUuid) &&
destinationUuid.equals(that.destinationUuid) &&
Arrays.equals(content, that.content);
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity) o;
return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice
&& serverTimestamp == that.serverTimestamp && guid.equals(that.guid) && Objects.equals(source, that.source)
&& Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid)
&& Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content);
}
@Override
public int hashCode() {
int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, serverTimestamp);
int result = Objects.hash(guid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, updatedPni,
serverTimestamp);
result = 31 * result + Arrays.hashCode(content);
return result;
}

View File

@ -19,7 +19,9 @@ import io.micrometer.core.instrument.Tags;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@ -28,10 +30,13 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
@ -39,6 +44,7 @@ import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameValidator;
import org.whispersystems.textsecuregcm.util.Util;
@ -152,6 +158,7 @@ public class AccountsManager {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId);
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
@ -220,7 +227,11 @@ public class AccountsManager {
}
}
public Account changeNumber(final Account account, final String number) throws InterruptedException {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber();
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
@ -228,6 +239,22 @@ public class AccountsManager {
return account;
}
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
// Check that all including master ID are in signed pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniRegistrationIds.keySet(),
Collections.emptySet());
} else if (pniSignedPreKeys != null || pniRegistrationIds != null) {
throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null");
}
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
deletedAccountsManager.lockAndPut(account.getNumber(), number, (originalAci, deletedAci) -> {
@ -252,7 +279,22 @@ public class AccountsManager {
try {
numberChangedAccount = updateWithRetries(
account,
a -> true,
a -> {
//noinspection ConstantConditions
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
pniRegistrationIds.forEach((deviceId, registrationId) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
}
if (pniIdentityKey != null) {
a.setPhoneNumberIdentityKey(pniIdentityKey);
}
return true;
},
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);

View File

@ -6,19 +6,25 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import javax.validation.constraints.NotNull;
import java.util.List;
import java.util.Map;
import java.util.Optional;
public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
@ -32,35 +38,54 @@ public class ChangeNumberManager {
this.accountsManager = accountsManager;
}
public Account changeNumber(
@NotNull Account account,
@NotNull final String number,
@NotNull final Map<Long, SignedPreKey> deviceSignedPrekeys,
@NotNull final List<IncomingMessage> deviceMessages) throws InterruptedException {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
assert pniIdentityKey != null;
assert deviceSignedPreKeys != null;
assert deviceMessages != null;
assert pniRegistrationIds != null;
// Check that all except master ID are in device messages
DestinationDeviceValidator.validateCompleteDeviceList(
account,
deviceMessages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()),
Set.of(Device.MASTER_ID));
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceMessages.stream()
.collect(Collectors.toMap(
IncomingMessage::getDestinationDeviceId,
IncomingMessage::getDestinationRegistrationId)),
false);
} else if (!ObjectUtils.allNull(deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
throw new IllegalArgumentException("Signed pre-keys, device messages, and registration IDs must be all null or all non-null");
}
final Account updatedAccount;
if (number.equals(account.getNumber())) {
// This may be a request that got repeated due to poor network conditions or other client error; take no action,
// but report success since the account is in the desired state
updatedAccount = account;
} else {
updatedAccount = accountsManager.changeNumber(account, number);
updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
}
// Whether the account already has this number or not, we reset signed prekeys and resend messages.
// This makes it so the client can resend a request they didn't get a response for (timeout, etc)
// to make sure their messages sent and prekeys were updated, even if the first time around the
// server crashed at/above this point.
if (deviceSignedPrekeys != null && !deviceSignedPrekeys.isEmpty()) {
for (Map.Entry<Long, SignedPreKey> entry : deviceSignedPrekeys.entrySet()) {
accountsManager.updateDevice(updatedAccount, entry.getKey(),
d -> d.setPhoneNumberIdentitySignedPreKey(entry.getValue()));
}
for (IncomingMessage message : deviceMessages) {
sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message);
}
// Whether the account already has this number or not, we resend messages. This makes it so the client can resend a
// request they didn't get a response for (timeout, etc) to make sure their messages sent even if the first time
// around the server crashed at/above this point.
if (deviceMessages != null) {
deviceMessages.forEach(message ->
sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.getDestinationDeviceId()), message));
}
return updatedAccount;
}
@ -86,6 +111,7 @@ public class ChangeNumberManager {
.setSource(sourceAndDestinationAccount.getNumber())
.setSourceUuid(sourceAndDestinationAccount.getUuid().toString())
.setSourceDevice((int) Device.MASTER_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.build();
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
} catch (NotPushRegisteredException e) {

View File

@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
@ -49,6 +50,10 @@ public class Device {
@JsonProperty
private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
private SignedPreKey signedPreKey;
@ -184,6 +189,14 @@ public class Device {
this.registrationId = registrationId;
}
public OptionalInt getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty();
}
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}
public SignedPreKey getSignedPreKey() {
return signedPreKey;
}

View File

@ -389,6 +389,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}

View File

@ -46,6 +46,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
private static final String KEY_SOURCE_UUID = "SU";
private static final String KEY_SOURCE_DEVICE = "SD";
private static final String KEY_DESTINATION_UUID = "DU";
private static final String KEY_UPDATED_PNI = "UP";
private static final String KEY_CONTENT = "C";
private static final String KEY_TTL = "E";
@ -85,10 +86,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.put(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT, convertLocalIndexMessageUuidSortKey(messageUuid))
.put(KEY_TYPE, AttributeValues.fromInt(message.getType().getNumber()))
.put(KEY_TIMESTAMP, AttributeValues.fromLong(message.getTimestamp()))
.put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message)));
item.put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid())));
.put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message)))
.put(KEY_DESTINATION_UUID, AttributeValues.fromUUID(UUID.fromString(message.getDestinationUuid())));
if (message.hasUpdatedPni()) {
item.put(KEY_UPDATED_PNI, AttributeValues.fromUUID(UUID.fromString(message.getUpdatedPni())));
}
if (message.hasSource()) {
item.put(KEY_SOURCE, AttributeValues.fromString(message.getSource()));
}
@ -240,7 +243,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, content, sortKey.getServerTimestamp());
final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid,
updatedPni, content, sortKey.getServerTimestamp());
}
private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) {

View File

@ -0,0 +1,95 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator {
/**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs.
*
* @param account the destination account against which to check the given device ID/registration ID pairs
* @param registrationIdsByDeviceId a map of device IDs to registration IDs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device registration IDs
* associated with the account's PNI (if available); compare against the ACI-associated
* registration ID otherwise
*
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Map<Long, Integer> registrationIdsByDeviceId,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Long> staleDevices = new ArrayList<>();
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
if (registrationId > 0) {
final boolean registrationIdMatches =
account.getDevice(deviceId).map(device -> registrationId == (usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId()))
.orElse(false);
if (!registrationIdMatches) {
staleDevices.add(deviceId);
}
}
});
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
/**
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
* given destination account in preparation for sending those messages to the destination account. In general, the set
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
*
* @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message)
*
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Long> messageDeviceIds,
final Set<Long> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Long> accountDeviceIds = account.getDevices().stream()
.filter(Device::isEnabled)
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Long> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Long> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
}
}
}

View File

@ -1,84 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MessageValidation {
public static <T> void validateRegistrationIds(Account account, List<T> messages, Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId)
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
.map(message -> new Pair<>(getDeviceId.apply(message), getRegistrationId.apply(message)));
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
public static void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
Optional<Device> device = account.getDevice(deviceIdAndRegistrationId.first());
return device.isPresent() && deviceIdAndRegistrationId.second() != device.get().getRegistrationId();
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
public static <T> void validateCompleteDeviceList(Account account, Collection<T> messages, Function<T, Long> getDeviceId, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(getDeviceId)
.collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
}
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) {
if (device.isEnabled() &&
!(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}
for (Long deviceId : messageDeviceIds) {
if (!accountDeviceIds.contains(deviceId)) {
extraDeviceIds.add(deviceId);
}
}
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(missingDeviceIds, extraDeviceIds);
}
}
}

View File

@ -333,8 +333,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
builder.setDestinationUuid(message.getDestinationUuid().toString());
if (message.getUpdatedPni() != null) {
builder.setUpdatedPni(message.getUpdatedPni().toString());
}
builder.setServerGuid(message.getGuid().toString());
final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {

View File

@ -41,6 +41,8 @@ message Envelope {
optional uint64 server_timestamp = 10;
optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline
optional string destination_uuid = 13;
optional string updated_pni = 15;
// next: 16
}
message ProvisioningUuid {

View File

@ -27,7 +27,6 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import com.google.common.collect.ImmutableSet;
import com.vdurmont.semver4j.Semver;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
@ -38,7 +37,6 @@ import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
@ -92,6 +90,7 @@ class MessageControllerTest {
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
@ -127,31 +126,33 @@ class MessageControllerTest {
@BeforeEach
void setup() {
final List<Device> singleDeviceList = List.of(
generateTestDevice(1, 111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
);
final List<Device> multiDeviceList = List.of(
generateTestDevice(1, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(2, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(3, 444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
);
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, UUID.randomUUID(), multiDeviceList, "1234".getBytes());
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes());
internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes());
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByAccountIdentifier(eq(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(MULTI_DEVICE_PNI)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
}
private static Device generateTestDevice(final long id, final int registrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
device.setSignedPreKey(signedPreKey);
device.setCreated(createdAt);
device.setLastSeen(lastSeen);
@ -197,6 +198,28 @@ class MessageControllerTest {
}
}
private static Stream<Entity<?>> currentMessageSingleDevicePayloadsPni() {
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(1); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0x04, (byte)0x57 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource("currentMessageSingleDevicePayloads")
void testSendFromDisabledAccount(Entity<?> payload) throws Exception {
@ -230,7 +253,7 @@ class MessageControllerTest {
}
@ParameterizedTest
@MethodSource("currentMessageSingleDevicePayloads")
@MethodSource("currentMessageSingleDevicePayloadsPni")
void testSingleDeviceCurrentByPni(Entity<?> payload) throws Exception {
Response response =
resources.getJerseyTest()
@ -403,6 +426,50 @@ class MessageControllerTest {
}
}
@ParameterizedTest
@MethodSource
void testMultiDeviceByPni(Entity<?> payload) throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(payload);
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false));
}
private static Stream<Entity<?>> testMultiDeviceByPni() {
ByteArrayOutputStream messageStream = new ByteArrayOutputStream();
messageStream.write(1); // version
messageStream.write(2); // count
messageStream.write(1); // device ID
messageStream.writeBytes(new byte[] { (byte)0x08, (byte)0xae }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
messageStream.write(2); // device ID
messageStream.writeBytes(new byte[] { (byte)0x0d, (byte)0x05 }); // registration ID
messageStream.write(1); // message type
messageStream.write(3); // message length
messageStream.writeBytes(new byte[] { (byte)1, (byte)2, (byte)3 }); // message contents
try {
return Stream.of(
Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE),
Entity.entity(messageStream.toByteArray(), MultiDeviceMessageListProvider.MEDIA_TYPE)
);
} catch (Exception e) {
throw new AssertionError(e);
}
}
@ParameterizedTest
@MethodSource
void testRegistrationIdMismatch(Entity<?> payload) throws Exception {
@ -459,9 +526,11 @@ class MessageControllerTest {
final UUID messageGuidOne = UUID.randomUUID();
final UUID sourceUuid = UUID.randomUUID();
final UUID updatedPniOne = UUID.randomUUID();
List<OutgoingMessageEntity> messages = new LinkedList<>() {{
add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, 0));
add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -485,16 +554,19 @@ class MessageControllerTest {
assertEquals(response.getMessages().get(0).getSourceUuid(), sourceUuid);
assertEquals(response.getMessages().get(1).getSourceUuid(), sourceUuid);
assertEquals(updatedPniOne, response.getMessages().get(0).getUpdatedPni());
assertNull(response.getMessages().get(1).getUpdatedPni());
}
@Test
void testGetMessagesBadAuth() throws Exception {
void testGetMessagesBadAuth() {
final long timestampOne = 313377;
final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, 0));
List<OutgoingMessageEntity> messages = new LinkedList<>() {{
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -520,12 +592,12 @@ class MessageControllerTest {
UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(new OutgoingMessageEntity(
uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, "hi".getBytes(), 0)));
timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(new OutgoingMessageEntity(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, 0)));
System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null)).thenReturn(Optional.empty());

View File

@ -15,14 +15,18 @@ import static org.mockito.Mockito.when;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@ -198,7 +202,7 @@ class AccountsManagerChangeNumberIntegrationTest {
}
@Test
void testChangeNumber() throws InterruptedException {
void testChangeNumber() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
@ -206,7 +210,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account, secondNumber);
accountsManager.changeNumber(account, secondNumber, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -221,7 +225,46 @@ class AccountsManagerChangeNumberIntegrationTest {
}
@Test
void testChangeNumberReturnToOriginal() throws InterruptedException {
void testChangeNumberWithPniExtensions() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
final int rotatedPniRegistrationId = 17;
final SignedPreKey rotatedSignedPreKey = new SignedPreKey(1, "test", "test");
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities());
final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>());
account.getMasterDevice().orElseThrow().setSignedPreKey(new SignedPreKey());
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
final String pniIdentityKey = "changed-pni-identity-key";
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, registrationIds);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
assertTrue(accountsManager.getByE164(secondNumber).isPresent());
assertEquals(originalUuid, accountsManager.getByE164(secondNumber).map(Account::getUuid).orElseThrow());
assertNotEquals(originalPni, accountsManager.getByE164(secondNumber).map(Account::getPhoneNumberIdentifier).orElseThrow());
assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow());
assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey());
assertEquals(OptionalInt.of(rotatedPniRegistrationId),
updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentityRegistrationId());
assertEquals(rotatedSignedPreKey, updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentitySignedPreKey());
}
@Test
void testChangeNumberReturnToOriginal() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
@ -229,8 +272,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
account = accountsManager.changeNumber(account, secondNumber);
accountsManager.changeNumber(account, originalNumber);
account = accountsManager.changeNumber(account, secondNumber, null, null, null);
accountsManager.changeNumber(account, originalNumber, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isPresent());
assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow());
@ -245,7 +288,7 @@ class AccountsManagerChangeNumberIntegrationTest {
}
@Test
void testChangeNumberContested() throws InterruptedException {
void testChangeNumberContested() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
@ -255,7 +298,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber);
accountsManager.changeNumber(account, secondNumber, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -269,7 +312,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber);
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null);
final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(),
new ArrayList<>());
@ -278,7 +321,7 @@ class AccountsManagerChangeNumberIntegrationTest {
}
@Test
void testChangeNumberChaining() throws InterruptedException {
void testChangeNumberChaining() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
@ -289,7 +332,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid();
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber);
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
@ -300,7 +343,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber);
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null);
assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));

View File

@ -45,6 +45,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@ -641,7 +642,7 @@ class AccountsManagerTest {
}
@Test
void testChangePhoneNumber() throws InterruptedException {
void testChangePhoneNumber() throws InterruptedException, MismatchedDevicesException {
doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty()))
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
@ -651,7 +652,7 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber);
account = accountsManager.changeNumber(account, targetNumber, null, null, null);
assertEquals(targetNumber, account.getNumber());
@ -663,11 +664,11 @@ class AccountsManagerTest {
}
@Test
void testChangePhoneNumberSameNumber() throws InterruptedException {
void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDevicesException {
final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, number);
account = accountsManager.changeNumber(account, number, null, null, null);
assertEquals(number, account.getNumber());
verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any());
@ -676,7 +677,7 @@ class AccountsManagerTest {
}
@Test
void testChangePhoneNumberExistingAccount() throws InterruptedException {
void testChangePhoneNumberExistingAccount() throws InterruptedException, MismatchedDevicesException {
doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty()))
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
@ -691,7 +692,7 @@ class AccountsManagerTest {
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber);
account = accountsManager.changeNumber(account, targetNumber, null, null, null);
assertEquals(targetNumber, account.getNumber());

View File

@ -4,6 +4,8 @@
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
@ -11,7 +13,9 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -19,31 +23,43 @@ import java.util.UUID;
import org.apache.commons.codec.binary.Base64;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
public class ChangeNumberManagerTest {
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessageSender messageSender = mock(MessageSender.class);
private ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
private AccountsManager accountsManager;
private MessageSender messageSender;
private ChangeNumberManager changeNumberManager;
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
@BeforeEach
void reset() throws Exception {
Mockito.reset(accountsManager, messageSender);
when(accountsManager.changeNumber(any(), any())).thenAnswer((Answer<Account>) invocation -> {
void setUp() throws Exception {
accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
when(accountsManager.changeNumber(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getUuid();
final List<Device> devices = account.getDevices();
final UUID updatedPni = UUID.randomUUID();
updatedPhoneNumberIdentifiersByAccount.put(account, updatedPni);
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
@ -58,8 +74,8 @@ public class ChangeNumberManagerTest {
void changeNumberNoMessages() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", Collections.EMPTY_MAP, Collections.EMPTY_LIST);
verify(accountsManager).changeNumber(account, "+18025551234");
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null);
verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@ -69,27 +85,112 @@ public class ChangeNumberManagerTest {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
var prekeys = Map.of(1L, new SignedPreKey());
changeNumberManager.changeNumber(account, "+18025551234", prekeys, Collections.EMPTY_LIST);
verify(accountsManager).changeNumber(account, "+18025551234");
verify(accountsManager).updateDevice(any(), eq(1L), any());
final String pniIdentityKey = "pni-identity-key";
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyMap());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@Test
void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
Device d2 = mock(Device.class);
final String originalE164 = "+18005551234";
final String changedE164 = "+18025551234";
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn(originalE164);
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
var prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
IncomingMessage msg = mock(IncomingMessage.class);
when(account.getDevices()).thenReturn(List.of(d2));
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.getDestinationDeviceId()).thenReturn(2L);
when(msg.getContent()).thenReturn(Base64.encodeBase64String(new byte[]{1}));
changeNumberManager.changeNumber(account, "+18025551234", prekeys, List.of(msg));
verify(accountsManager).changeNumber(account, "+18025551234");
verify(accountsManager).updateDevice(any(), eq(1L), any());
verify(accountsManager).updateDevice(any(), eq(2L), any());
verify(messageSender).sendMessage(any(), eq(d2), any(), eq(false));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, List.of(msg), registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(changedE164, envelope.getSource());
assertEquals(Device.MASTER_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@Test
void changeNumberMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 3, 1, "foo"));
final Map<Long, SignedPreKey> preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds));
}
@Test
void changeNumberMissingData() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, null, 2, 2, "foo"),
new IncomingMessage(1, null, 3, 3, "foo"));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, messages, registrationIds));
}
}

View File

@ -253,9 +253,10 @@ class AccountControllerTest {
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername"))
.thenThrow(new UsernameNotAvailableException());
when(changeNumberManager.changeNumber(any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final String pniIdentityKey = invocation.getArgument(2, String.class);
final UUID uuid = account.getUuid();
final List<Device> devices = account.getDevices();
@ -263,8 +264,10 @@ class AccountControllerTest {
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
@ -1298,10 +1301,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
@ -1318,12 +1321,12 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.readEntity(String.class)).isBlank();
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1336,7 +1339,7 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
@ -1345,7 +1348,7 @@ class AccountControllerTest {
assertThat(responseEntity.getOriginalNumber()).isEqualTo(number);
assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111");
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1355,10 +1358,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
@ -1373,11 +1376,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1393,11 +1396,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1423,11 +1426,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
@ -1453,11 +1456,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1485,11 +1488,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@ -1517,82 +1520,33 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
}
@Test
void testChangePhoneNumberDeviceMessagesWithoutPrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null,
List.of(new IncomingMessage(1, null, 1, 1, "foo")), null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchDeviceIDs() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 4, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
final String pniIdentityKey = "changed-pni-identity-key";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
@ -1601,6 +1555,8 @@ class AccountControllerTest {
new IncomingMessage(1, null, 3, 3, "content3"));
var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v1/accounts/number")
@ -1608,53 +1564,18 @@ class AccountControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
deviceMessages,
deviceKeys),
pniIdentityKey, deviceMessages,
deviceKeys,
registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
assertThat(accountIdentityResponse.getPni()).isNotEqualTo(AuthHelper.VALID_PNI);
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchRegistrationID() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 3, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(410);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void testSetRegistrationLock() {
Response response =

View File

@ -28,6 +28,7 @@ import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
@ -73,6 +74,8 @@ class KeysControllerTest {
private static final int SAMPLE_REGISTRATION_ID2 = 1002;
private static final int SAMPLE_REGISTRATION_ID4 = 1555;
private static final int SAMPLE_PNI_REGISTRATION_ID = 1717;
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
@ -106,9 +109,11 @@ class KeysControllerTest {
.addResource(new RateLimitExceededExceptionMapper())
.build();
private Device sampleDevice;
@BeforeEach
void setup() {
final Device sampleDevice = mock(Device.class);
sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
@ -121,6 +126,7 @@ class KeysControllerTest {
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(SAMPLE_PNI_REGISTRATION_ID));
when(sampleDevice.isEnabled()).thenReturn(true);
when(sampleDevice2.isEnabled()).thenReturn(true);
when(sampleDevice3.isEnabled()).thenReturn(false);
@ -284,6 +290,7 @@ class KeysControllerTest {
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1);
@ -302,6 +309,28 @@ class KeysControllerTest {
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() {
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);

View File

@ -1,227 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageValidationTest {
static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) {
Account account = mock(Account.class);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
}
return account;
}
static Collection<Pair<Long, Integer>> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) {
final Collection<Pair<Long, Integer>> result = new HashSet<>(deviceAndRegistrationIds.length);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
result.add(new Pair<>(deviceId, registrationId));
}
return result;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 255),
deviceAndRegistrationIds(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 256),
deviceAndRegistrationIds(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Collection<Pair<Long, Integer>> deviceAndRegistrationIds,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds);
} else {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}
}
static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) {
Account account = mock(Account.class);
if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
final List<Device> devices = new ArrayList<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) {
throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1));
}
Long deviceId = (Long) deviceIdAndEnabled[i];
Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1];
Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
}
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
}
}
}

View File

@ -0,0 +1,214 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ExtendWith(DropwizardExtensionsSupport.class)
class DestinationDeviceValidatorTest {
static Account mockAccountWithDeviceAndRegId(final Map<Long, Integer> registrationIdsByDeviceId) {
final Account account = mock(Account.class);
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
final Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
});
return account;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF)),
Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 255)),
Map.of(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 256)),
Map.of(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Map<Long, Integer> registrationIdsByDeviceId,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false)).getStaleDevices())
.hasSameElementsAs(expectedStaleDeviceIds);
} else {
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId, false);
}
}
static Account mockAccountWithDeviceAndEnabled(final Map<Long, Boolean> enabledStateByDeviceId) {
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
final Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
});
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 3L),
null,
null,
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 2L, 3L),
null,
Set.of(2L),
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L),
Set.of(3L),
null,
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L),
Set.of(3L),
Set.of(1L),
Set.of(1L)
),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(2L),
Set.of(3L),
Set.of(2L),
Set.of(1L)
),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(3L),
null,
null,
Set.of(1L)
)
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
Set<Long> excludedDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
}
}
@Test
void testValidatePniRegistrationIds() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.MASTER_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final int aciRegistrationId = 17;
final int pniRegistrationId = 89;
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false));
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, pniRegistrationId), false));
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), false));
assertDoesNotThrow(() -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, aciRegistrationId), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, Map.of(Device.MASTER_ID, incorrectRegistrationId), false));
}
}

View File

@ -64,6 +64,7 @@ import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.annotation.Nullable;
class WebSocketConnectionTest {
@ -285,6 +286,7 @@ class WebSocketConnectionTest {
.setSource("sender1")
.setSourceUuid(UUID.randomUUID().toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
@ -302,12 +304,12 @@ class WebSocketConnectionTest {
List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(UUID.randomUUID(), firstMessage.getType().getNumber(),
firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()),
firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()),
firstMessage.getContent().toByteArray(), 0));
firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), UUID.fromString(firstMessage.getUpdatedPni()),
firstMessage.getContent().toByteArray(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), secondMessage.getType().getNumber(),
secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()),
secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()),
secondMessage.getContent().toByteArray(), 0));
secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), null,
secondMessage.getContent().toByteArray(), 0));
}};
OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false);
@ -884,7 +886,7 @@ class WebSocketConnectionTest {
private OutgoingMessageEntity createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(UUID.randomUUID(), receipt ? Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sender, senderUuid, 1, destinationUuid, content.getBytes(), 0);
timestamp, sender, senderUuid, 1, destinationUuid, null, content.getBytes(), 0);
}
}

View File

@ -0,0 +1,16 @@
{
"messages" : [{
"type" : 1,
"destinationDeviceId" : 1,
"destinationRegistrationId" : 2222,
"content" : "Zm9vYmFyego",
"timestamp" : 1234
},
{
"type" : 1,
"destinationDeviceId" : 2,
"destinationRegistrationId" : 3333,
"content" : "Zm9vYmFyego",
"timestamp" : 1234
}]
}