Require PQ keys when changing numbers or distributing key material

This commit is contained in:
Jon Chambers 2025-05-05 14:54:12 -04:00 committed by Jon Chambers
parent e43487155f
commit b400d49e77
11 changed files with 224 additions and 327 deletions

View File

@ -17,6 +17,7 @@ import javax.annotation.Nullable;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
@ -51,34 +52,27 @@ public record ChangeNumberRequest(
arraySchema=@Schema(description=""" arraySchema=@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keysManager A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys. associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")) Exactly one message must be supplied for each device other than the sending (primary) device."""))
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@Schema(description=""" @Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one. A new signed elliptic-curve prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys, @NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description=""" @Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. A new signed post-quantum last-resort prekey for each device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys, @NotNull @NotEmpty @Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") @Schema(description="the new phone-number-identity registration ID for each device on the account, including this one")
@NotNull Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest { @NotNull @NotEmpty Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) { public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
List<SignedPreKey<?>> spks = new ArrayList<>(); final List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values()); spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number"); return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number");
} }
@AssertTrue @AssertTrue

View File

@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.NotNull;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -29,36 +30,36 @@ public record PhoneNumberIdentityKeyDistributionRequest(
arraySchema=@Schema(description=""" arraySchema=@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys. associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device. Exactly one message must be supplied for each device other than the sending (primary) device.
""")) """))
List<@NotNull @Valid IncomingMessage> deviceMessages, List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @NotNull
@NotEmpty
@Valid @Valid
@Schema(description=""" @Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one. A new signed elliptic-curve prekey for each device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys, Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@NotNull
@NotEmpty
@Valid
@Schema(description=""" @Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. A new signed post-quantum last-resort prekey for each device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys, Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@NotNull @NotNull
@NotEmpty
@Valid @Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.") @Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.")
Map<Byte, Integer> pniRegistrationIds) { Map<Byte, Integer> pniRegistrationIds) {
public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) { public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) {
List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values()); final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniPqLastResortPrekeys != null) { signedPreKeys.addAll(devicePniPqLastResortPrekeys.values());
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "distribute-pni-keys");
}
return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, signedPreKeys, userAgent, "distribute-pni-keys");
}
} }

View File

@ -62,7 +62,6 @@ import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.crypto.Mac; import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -642,18 +641,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
public Account changeNumber(final Account account, public Account changeNumber(final Account account,
final String targetNumber, final String targetNumber,
@Nullable final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys, final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys, final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier(); final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join(); final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join();
if (originalPhoneNumberIdentifier.equals(targetPhoneNumberIdentifier)) { if (originalPhoneNumberIdentifier.equals(targetPhoneNumberIdentifier)) {
if (pniIdentityKey != null) {
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
}
return account; return account;
} }
@ -694,7 +690,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
.join(); .join();
final Collection<TransactWriteItem> keyWriteItems = final Collection<TransactWriteItem> keyWriteItems =
buildPniKeyWriteItems(uuid, targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
final Account numberChangedAccount = updateWithRetries( final Account numberChangedAccount = updateWithRetries(
account, account,
@ -715,7 +711,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
public Account updatePniKeys(final Account account, public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys, final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys, final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException { final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
@ -724,7 +720,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
final UUID pni = account.getIdentifier(IdentityType.PNI); final UUID pni = account.getIdentifier(IdentityType.PNI);
final Collection<TransactWriteItem> keyWriteItems = final Collection<TransactWriteItem> keyWriteItems =
buildPniKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys); buildPniKeyWriteItems(pni, pniSignedPreKeys, pniPqLastResortPreKeys);
return redisDeleteAsync(account) return redisDeleteAsync(account)
.thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni)) .thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni))
@ -739,41 +735,24 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} }
private Collection<TransactWriteItem> buildPniKeyWriteItems( private Collection<TransactWriteItem> buildPniKeyWriteItems(
final UUID enabledDevicesIdentifier,
final UUID phoneNumberIdentifier, final UUID phoneNumberIdentifier,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys, final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) { final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys) {
final List<TransactWriteItem> keyWriteItems = new ArrayList<>(); final List<TransactWriteItem> keyWriteItems = new ArrayList<>();
if (pniSignedPreKeys != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) -> pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey))); keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey)));
}
if (pniPqLastResortPreKeys != null) { pniPqLastResortPreKeys.forEach((deviceId, lastResortKey) ->
keysManager.getPqEnabledDevices(enabledDevicesIdentifier) keyWriteItems.add(keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, deviceId, lastResortKey)));
.thenAccept(deviceIds -> deviceIds.stream()
.filter(pniPqLastResortPreKeys::containsKey)
.map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier,
deviceId,
pniPqLastResortPreKeys.get(deviceId)))
.forEach(keyWriteItems::add))
.join();
}
return keyWriteItems; return keyWriteItems;
} }
private void setPniKeys(final Account account, private void setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, Integer> pniRegistrationIds) { final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) {
return;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null");
}
account.getDevices() account.getDevices()
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId()))); .forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
@ -782,22 +761,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} }
private void validateDevices(final Account account, private void validateDevices(final Account account,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys, final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys, final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException { final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null");
}
// Check that all including primary ID are in signed pre-keys // Check that all including primary ID are in signed pre-keys
validateCompleteDeviceList(account, pniSignedPreKeys.keySet()); validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
// Check that all including primary ID are in Pq pre-keys // Check that all including primary ID are in Pq pre-keys
if (pniPqLastResortPreKeys != null) {
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet()); validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
}
// Check that all devices are accounted for in the map of new PNI registration IDs // Check that all devices are accounted for in the map of new PNI registration IDs
validateCompleteDeviceList(account, pniRegistrationIds.keySet()); validateCompleteDeviceList(account, pniRegistrationIds.keySet());
@ -816,8 +788,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
extraDeviceIds.removeAll(accountDeviceIds); extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException( throw new MismatchedDevicesException(new MismatchedDevices(missingDeviceIds, extraDeviceIds, Set.of()));
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
} }
} }

View File

@ -42,13 +42,14 @@ public class ChangeNumberManager {
this.clock = clock; this.clock = clock;
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account,
@Nullable final IdentityKey pniIdentityKey, final String number,
@Nullable final Map<Byte, ECSignedPreKey> deviceSignedPreKeys, final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys, final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final List<IncomingMessage> deviceMessages, final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds, final List<IncomingMessage> deviceMessages,
@Nullable final String senderUserAgent) final Map<Byte, Integer> pniRegistrationIds,
final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException { throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) || if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -114,10 +113,6 @@ public class KeysManager {
return ecSignedPreKeys.find(identifier, deviceId); return ecSignedPreKeys.find(identifier, deviceId);
} }
public CompletableFuture<List<Byte>> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
}
public CompletableFuture<Integer> getEcCount(final UUID identifier, final byte deviceId) { public CompletableFuture<Integer> getEcCount(final UUID identifier, final byte deviceId) {
return ecPreKeys.getCount(identifier, deviceId); return ecPreKeys.getCount(identifier, deviceId);
} }

View File

@ -14,14 +14,12 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.Put; import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
/** /**
@ -116,20 +114,6 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
return findFuture; return findFuture;
} }
public Flux<Byte> getDeviceIdsWithKeys(final UUID identifier) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier)))
.projectionExpression(KEY_DEVICE_ID)
.consistentRead(true)
.build())
.items())
.map(item -> Byte.parseByte(item.get(KEY_DEVICE_ID).n()));
}
protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final byte deviceId) { protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final byte deviceId) {
return Map.of( return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_ACCOUNT_UUID, getPartitionKey(identifier),

View File

@ -72,6 +72,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
@ -92,6 +94,7 @@ import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -100,7 +103,8 @@ class AccountControllerV2Test {
private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds(); private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds();
private static final IdentityKey IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private static final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey());
private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format( private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
@ -185,9 +189,11 @@ class AccountControllerV2Test {
.header(HttpHeaders.AUTHORIZATION, .header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY,
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), null, Collections.emptyMap()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, 17)),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
@ -207,10 +213,11 @@ class AccountControllerV2Test {
.header(HttpHeaders.AUTHORIZATION, .header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, IDENTITY_KEY,
new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), null, Collections.emptyMap()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, 17)),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(),
@ -291,9 +298,11 @@ class AccountControllerV2Test {
.thenReturn(CompletableFuture.completedFuture( .thenReturn(CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null, Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS)))); SESSION_EXPIRATION_SECONDS))));
final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(encodeSessionId("session"), final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY,
null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), Collections.emptyList(),
Collections.emptyList(), Collections.emptyMap(), null, Map.of((byte) 1, pniRegistrationId)); Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, pniRegistrationId));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v2/accounts/number") .target("/v2/accounts/number")
@ -503,9 +512,11 @@ class AccountControllerV2Test {
.header(HttpHeaders.AUTHORIZATION, .header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY,
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), null, Collections.emptyMap()), Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)),
Map.of(Device.PRIMARY_ID, 17)),
MediaType.APPLICATION_JSON_TYPE))) { MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(413, response.getStatus()); assertEquals(413, response.getStatus());
@ -519,17 +530,17 @@ class AccountControllerV2Test {
return requestJson("", recoveryPassword, newNumber, 123); return requestJson("", recoveryPassword, newNumber, 123);
} }
/**
* Valid request JSON with the given pniRegistrationId
*/
private static String requestJsonRegistrationIds(final Integer pniRegistrationId) {
return requestJson("", new byte[0], "+18005551234", pniRegistrationId);
}
/** /**
* Valid request JSON with the give session ID and recovery password * Valid request JSON with the give session ID and recovery password
*/ */
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber, final Integer pniRegistrationId) { private static String requestJson(final String sessionId,
final byte[] recoveryPassword,
final String newNumber,
final Integer pniRegistrationId) {
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
final KEMSignedPreKey pniLastResortPreKey = KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR);
return String.format(""" return String.format("""
{ {
"sessionId": "%s", "sessionId": "%s",
@ -538,10 +549,17 @@ class AccountControllerV2Test {
"reglock": "1234", "reglock": "1234",
"pniIdentityKey": "%s", "pniIdentityKey": "%s",
"deviceMessages": [], "deviceMessages": [],
"devicePniSignedPrekeys": {}, "devicePniSignedPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}},
"devicePniPqLastResortPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}},
"pniRegistrationIds": {"1": %d} "pniRegistrationIds": {"1": %d}
} }
""", encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()), pniRegistrationId); """, encodeSessionId(sessionId),
encodeRecoveryPassword(recoveryPassword),
newNumber,
Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()),
pniSignedPreKey.keyId(), Base64.getEncoder().encodeToString(pniSignedPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniSignedPreKey.signature()),
pniLastResortPreKey.keyId(), Base64.getEncoder().encodeToString(pniLastResortPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniLastResortPreKey.signature()),
pniRegistrationId);
} }
/** /**
@ -698,15 +716,21 @@ class AccountControllerV2Test {
* Valid request JSON for a {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest} * Valid request JSON for a {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest}
*/ */
private static String requestJson() { private static String requestJson() {
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
final KEMSignedPreKey pniLastResortPreKey = KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR);
return String.format(""" return String.format("""
{ {
"pniIdentityKey": "%s", "pniIdentityKey": "%s",
"deviceMessages": [], "deviceMessages": [],
"devicePniSignedPrekeys": {}, "devicePniSignedPrekeys": {},
"devicePniSignedPqPrekeys": {}, "devicePniSignedPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}},
"pniRegistrationIds": {} "devicePniPqLastResortPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}},
"pniRegistrationIds": {"1": 17}
} }
""", Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize())); """, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()),
pniSignedPreKey.keyId(), Base64.getEncoder().encodeToString(pniSignedPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniSignedPreKey.signature()),
pniLastResortPreKey.keyId(), Base64.getEncoder().encodeToString(pniLastResortPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniLastResortPreKey.signature()));
} }
/** /**

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -173,7 +174,14 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account, secondNumber, null, null, null, null); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
accountsManager.changeNumber(account,
secondNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -193,6 +201,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final int rotatedPniRegistrationId = 17; final int rotatedPniRegistrationId = 17;
final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair();
final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair); final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair);
final KEMSignedPreKey rotatedKemSignedPreKey = KeysHelper.signedKEMPreKey(2L, rotatedPniIdentityKeyPair);
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of()); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of());
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes);
@ -204,9 +213,10 @@ class AccountsManagerChangeNumberIntegrationTest {
final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Byte, KEMSignedPreKey> kemSignedPreKeys = Map.of(Device.PRIMARY_ID, rotatedKemSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds);
final UUID secondPni = updatedAccount.getPhoneNumberIdentifier(); final UUID secondPni = updatedAccount.getPhoneNumberIdentifier();
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -240,9 +250,24 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
account = accountsManager.changeNumber(account, secondNumber, null, null, null, null); final ECKeyPair originalIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair secondIdentityKeyPair = Curve.generateKeyPair();
account = accountsManager.changeNumber(account,
secondNumber,
new IdentityKey(secondIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, secondIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
final UUID secondPni = account.getPhoneNumberIdentifier(); final UUID secondPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account, originalNumber, null, null, null, null);
accountsManager.changeNumber(account,
originalNumber,
new IdentityKey(originalIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(3, originalIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, originalIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 2));
assertTrue(accountsManager.getByE164(originalNumber).isPresent()); assertTrue(accountsManager.getByE164(originalNumber).isPresent());
assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow()); assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow());
@ -266,11 +291,20 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
final ECKeyPair originalIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair secondIdentityKeyPair = Curve.generateKeyPair();
final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber, null, null, null, null); accountsManager.changeNumber(account,
secondNumber,
new IdentityKey(secondIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, secondIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
final UUID secondPni = accountsManager.getByE164(secondNumber).get().getPhoneNumberIdentifier(); final UUID secondPni = accountsManager.getByE164(secondNumber).get().getPhoneNumberIdentifier();
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -285,7 +319,12 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null); accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(),
originalNumber,
new IdentityKey(originalIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, originalIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, originalIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber); final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber);
@ -305,8 +344,15 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Account changedNumberAccount = accountsManager.changeNumber(account,
secondNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber); final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber);
@ -317,7 +363,14 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null, null); final ECKeyPair reRegisteredPniIdentityKeyPair = Curve.generateKeyPair();
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount,
secondNumber,
new IdentityKey(reRegisteredPniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, reRegisteredPniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, reRegisteredPniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
assertEquals(Optional.of(originalUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.of(originalUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));

View File

@ -1048,9 +1048,18 @@ class AccountsManagerTest {
final String targetNumber = "+14153333333"; final String targetNumber = "+14153333333";
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(1, pniIdentityKeyPair);
account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); final KEMSignedPreKey kemLastResortPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair);
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
account = accountsManager.changeNumber(account,
targetNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, ecSignedPreKey),
Map.of(Device.PRIMARY_ID, kemLastResortPreKey),
Map.of(Device.PRIMARY_ID, 1));
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
@ -1058,18 +1067,26 @@ class AccountsManagerTest {
verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager).deleteSingleUsePreKeys(originalPni);
verify(keysManager).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(targetNumber)); verify(keysManager).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(targetNumber));
verify(keysManager).buildWriteItemForEcSignedPreKey(phoneNumberIdentifiersByE164.get(targetNumber), Device.PRIMARY_ID, ecSignedPreKey);
verify(keysManager).buildWriteItemForLastResortKey(phoneNumberIdentifiersByE164.get(targetNumber), Device.PRIMARY_ID, kemLastResortPreKey);
} }
@Test @Test
void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDevicesException { void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDevicesException {
final String number = "+14152222222"; final String number = "+14152222222";
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
phoneNumberIdentifiersByE164.put(number, account.getPhoneNumberIdentifier()); phoneNumberIdentifiersByE164.put(number, account.getPhoneNumberIdentifier());
account = accountsManager.changeNumber(account, number, null, null, null, null); account = accountsManager.changeNumber(account,
number,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
assertEquals(number, account.getNumber()); assertEquals(number, account.getNumber());
verify(keysManager, never()).deleteSingleUsePreKeys(any()); verifyNoInteractions(keysManager);
} }
@Test @Test
@ -1077,31 +1094,20 @@ class AccountsManagerTest {
final String originalNumber = "+22923456789"; final String originalNumber = "+22923456789";
// the canonical form of numbers may change over time, so we use PNIs as stable identifiers // the canonical form of numbers may change over time, so we use PNIs as stable identifiers
final String newNumber = "+2290123456789"; final String newNumber = "+2290123456789";
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
Account account = AccountsHelper.generateTestAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), Account account = AccountsHelper.generateTestAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(),
new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
phoneNumberIdentifiersByE164.put(originalNumber, account.getPhoneNumberIdentifier()); phoneNumberIdentifiersByE164.put(originalNumber, account.getPhoneNumberIdentifier());
phoneNumberIdentifiersByE164.put(newNumber, account.getPhoneNumberIdentifier()); phoneNumberIdentifiersByE164.put(newNumber, account.getPhoneNumberIdentifier());
account = accountsManager.changeNumber(account, newNumber, null, null, null, null); account = accountsManager.changeNumber(account,
newNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)),
Map.of(Device.PRIMARY_ID, 1));
assertEquals(originalNumber, account.getNumber()); assertEquals(originalNumber, account.getNumber());
verify(keysManager, never()).deleteSingleUsePreKeys(any());
}
@Test
void testChangePhoneNumberSameNumberWithPniData() {
final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
phoneNumberIdentifiersByE164.put(number, account.getPhoneNumberIdentifier());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber(
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of((byte) 1, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any());
verifyNoInteractions(keysManager); verifyNoInteractions(keysManager);
} }
@ -1113,12 +1119,21 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(1, pniIdentityKeyPair);
account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); final KEMSignedPreKey kemLastResoryPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair);
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
account = accountsManager.changeNumber(account,
targetNumber,
new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Map.of(Device.PRIMARY_ID, ecSignedPreKey),
Map.of(Device.PRIMARY_ID, kemLastResoryPreKey),
Map.of(Device.PRIMARY_ID, 1));
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
@ -1129,6 +1144,9 @@ class AccountsManagerTest {
verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager).deleteSingleUsePreKeys(originalPni);
verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni);
verify(keysManager).deleteSingleUsePreKeys(newPni); verify(keysManager).deleteSingleUsePreKeys(newPni);
verify(keysManager).buildWriteItemsForRemovedDevice(existingAccountUuid, targetPni, Device.PRIMARY_ID);
verify(keysManager).buildWriteItemForEcSignedPreKey(newPni, Device.PRIMARY_ID, ecSignedPreKey);
verify(keysManager).buildWriteItemForLastResortKey(newPni, Device.PRIMARY_ID, kemLastResoryPreKey);
verifyNoMoreInteractions(keysManager); verifyNoMoreInteractions(keysManager);
} }
@ -1141,27 +1159,22 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2; final byte deviceId2 = 2;
final byte deviceId3 = 3;
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair), deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
deviceId3, KeysHelper.signedECPreKey(3, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of( final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, identityKeyPair), Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(5, identityKeyPair), deviceId2, KeysHelper.signedKEMPreKey(5, identityKeyPair));
deviceId3, KeysHelper.signedKEMPreKey(6, identityKeyPair)); final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202, deviceId3, 203);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3)));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final List<Device> devices = List.of( final List<Device> devices = List.of(
DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102), DevicesHelper.createDevice(deviceId2, 0L, 102));
DevicesHelper.createDisabledDevice(deviceId3, 103));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final Account updatedAccount = accountsManager.changeNumber( final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
@ -1175,24 +1188,20 @@ class AccountsManagerTest {
verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni);
verify(keysManager).deleteSingleUsePreKeys(newPni); verify(keysManager).deleteSingleUsePreKeys(newPni);
verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager).deleteSingleUsePreKeys(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId3), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(deviceId3), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(deviceId2), any());
verifyNoMoreInteractions(keysManager); verifyNoMoreInteractions(keysManager);
} }
@Test @Test
void testChangePhoneNumberWithMismatchedPqKeys() throws InterruptedException, MismatchedDevicesException { void testChangePhoneNumberWithMismatchedPqKeys() {
final String originalNumber = "+14152222222"; final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333"; final String targetNumber = "+14153333333";
final UUID existingAccountUuid = UUID.randomUUID();
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2; final byte deviceId2 = 2;
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1202,11 +1211,6 @@ class AccountsManagerTest {
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)); Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
final List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), final List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
@ -1242,6 +1246,9 @@ class AccountsManagerTest {
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, KEMSignedPreKey> newSignedKemKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid(); UUID oldUuid = account.getUuid();
@ -1249,10 +1256,9 @@ class AccountsManagerTest {
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList()));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedKemKeys, newRegistrationIds);
// non-PNI stuff should not change // non-PNI stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid()); assertEquals(oldUuid, updatedAccount.getUuid());
@ -1272,111 +1278,7 @@ class AccountsManagerTest {
verify(keysManager).deleteSingleUsePreKeys(oldPni); verify(keysManager).deleteSingleUsePreKeys(oldPni);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any());
}
@Test
void testPniPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount =
accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds);
// non-PNI-keys stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid());
assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).updateTransactionallyAsync(any(), any());
verify(keysManager).deleteSingleUsePreKeys(oldPni);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any());
}
@Test
void testPniNonPqToPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of()));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount =
accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds);
// non-PNI-keys stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid());
assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).updateTransactionallyAsync(any(), any());
verify(keysManager).deleteSingleUsePreKeys(oldPni);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any());
} }
@Test @Test

View File

@ -7,13 +7,10 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -184,8 +181,6 @@ class KeysManagerTest {
@Test @Test
void testStorePqLastResort() { void testStorePqLastResort() {
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final byte deviceId2 = 2; final byte deviceId2 = 2;
@ -194,35 +189,21 @@ class KeysManagerTest {
keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) 2, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, (byte) 2, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().orElseThrow().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent());
keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, deviceId3, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, deviceId3, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates"); assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId(),
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"storing new last-resort keys should overwrite old ones"); "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId(), assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().orElseThrow().keyId(),
"storing new last-resort keys should leave untouched ones alone"); "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(), assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().orElseThrow().keyId(),
"storing new last-resort keys should overwrite old ones"); "storing new last-resort keys should overwrite old ones");
} }
@Test
void testGetPqEnabledDevices() {
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2)).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4)).join();
assertIterableEquals(
Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
}
private static ECPreKey generateTestPreKey(final long keyId) { private static ECPreKey generateTestPreKey(final long keyId) {
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
} }

View File

@ -29,13 +29,4 @@ public class DevicesHelper {
return device; return device;
} }
public static Device createDisabledDevice(final byte deviceId, final int registrationId) {
final Device device = new Device();
device.setId(deviceId);
device.setUserAgent("OWT");
device.setRegistrationId(registrationId);
return device;
}
} }