From b400d49e773ee310432f5e3a41e795edefe061ac Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 5 May 2025 14:54:12 -0400 Subject: [PATCH] Require PQ keys when changing numbers or distributing key material --- .../entities/ChangeNumberRequest.java | 30 +-- ...eNumberIdentityKeyDistributionRequest.java | 27 +-- .../storage/AccountsManager.java | 71 ++---- .../storage/ChangeNumberManager.java | 15 +- .../textsecuregcm/storage/KeysManager.java | 5 - .../storage/RepeatedUseSignedPreKeyStore.java | 16 -- .../controllers/AccountControllerV2Test.java | 72 ++++-- ...ntsManagerChangeNumberIntegrationTest.java | 69 +++++- .../storage/AccountsManagerTest.java | 208 +++++------------- .../storage/KeysManagerTest.java | 29 +-- .../tests/util/DevicesHelper.java | 9 - 11 files changed, 224 insertions(+), 327 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index 5c686868c..789466f28 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -17,6 +17,7 @@ import javax.annotation.Nullable; import jakarta.validation.Valid; import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; @@ -51,34 +52,27 @@ public record ChangeNumberRequest( arraySchema=@Schema(description=""" A list of synchronization messages to send to companion devices to supply the private keysManager associated with the new identity key and their new prekeys. - Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")) + Exactly one message must be supplied for each device other than the sending (primary) device.""")) @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, @Schema(description=""" - A new signed elliptic-curve prekey for each enabled device on the account, including this one. + A new signed elliptic-curve prekey for each device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @NotNull @Valid Map devicePniSignedPrekeys, + @NotNull @NotEmpty @Valid Map devicePniSignedPrekeys, @Schema(description=""" - A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. - May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. - If present, must contain one prekey per enabled device including this one. - Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. + A new signed post-quantum last-resort prekey for each device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + @NotNull @NotEmpty @Valid Map devicePniPqLastResortPrekeys, - @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") - @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { + @Schema(description="the new phone-number-identity registration ID for each device on the account, including this one") + @NotNull @NotEmpty Map pniRegistrationIds) implements PhoneVerificationRequest { public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) { - List> spks = new ArrayList<>(); - if (devicePniSignedPrekeys != null) { - spks.addAll(devicePniSignedPrekeys.values()); - } - if (devicePniPqLastResortPrekeys != null) { - spks.addAll(devicePniPqLastResortPrekeys.values()); - } - return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number"); + final List> spks = new ArrayList<>(devicePniSignedPrekeys.values()); + spks.addAll(devicePniPqLastResortPrekeys.values()); + + return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "change-number"); } @AssertTrue diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index d05ddbf08..eaf5a4d5b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.Valid; +import jakarta.validation.constraints.NotEmpty; import jakarta.validation.constraints.NotNull; import java.util.ArrayList; import java.util.List; @@ -29,36 +30,36 @@ public record PhoneNumberIdentityKeyDistributionRequest( arraySchema=@Schema(description=""" A list of synchronization messages to send to companion devices to supply the private keys associated with the new identity key and their new prekeys. - Exactly one message must be supplied for each enabled device other than the sending (primary) device. + Exactly one message must be supplied for each device other than the sending (primary) device. """)) List<@NotNull @Valid IncomingMessage> deviceMessages, @NotNull + @NotEmpty @Valid @Schema(description=""" - A new signed elliptic-curve prekey for each enabled device on the account, including this one. + A new signed elliptic-curve prekey for each device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") Map devicePniSignedPrekeys, + @NotNull + @NotEmpty + @Valid @Schema(description=""" - A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. - May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. - If present, must contain one prekey per enabled device including this one. - Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. + A new signed post-quantum last-resort prekey for each device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + Map devicePniPqLastResortPrekeys, @NotNull + @NotEmpty @Valid @Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.") Map pniRegistrationIds) { public boolean isSignatureValidOnEachSignedPreKey(@Nullable final String userAgent) { - List> spks = new ArrayList<>(devicePniSignedPrekeys.values()); - if (devicePniPqLastResortPrekeys != null) { - spks.addAll(devicePniPqLastResortPrekeys.values()); - } - return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks, userAgent, "distribute-pni-keys"); - } + final List> signedPreKeys = new ArrayList<>(devicePniSignedPrekeys.values()); + signedPreKeys.addAll(devicePniPqLastResortPrekeys.values()); + return PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, signedPreKeys, userAgent, "distribute-pni-keys"); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 520bbf8b1..848db12a2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -62,7 +62,6 @@ import java.util.stream.Stream; import javax.annotation.Nullable; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; -import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; @@ -641,19 +640,16 @@ public class AccountsManager extends RedisPubSubAdapter implemen } public Account changeNumber(final Account account, - final String targetNumber, - @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, - @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { + final String targetNumber, + final IdentityKey pniIdentityKey, + final Map pniSignedPreKeys, + final Map pniPqLastResortPreKeys, + final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier(); final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber).join(); if (originalPhoneNumberIdentifier.equals(targetPhoneNumberIdentifier)) { - if (pniIdentityKey != null) { - throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys"); - } return account; } @@ -694,7 +690,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen .join(); final Collection keyWriteItems = - buildPniKeyWriteItems(uuid, targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); + buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); final Account numberChangedAccount = updateWithRetries( account, @@ -715,7 +711,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen public Account updatePniKeys(final Account account, final IdentityKey pniIdentityKey, final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, + final Map pniPqLastResortPreKeys, final Map pniRegistrationIds) throws MismatchedDevicesException { validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); @@ -724,7 +720,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen final UUID pni = account.getIdentifier(IdentityType.PNI); final Collection keyWriteItems = - buildPniKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys); + buildPniKeyWriteItems(pni, pniSignedPreKeys, pniPqLastResortPreKeys); return redisDeleteAsync(account) .thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni)) @@ -739,41 +735,24 @@ public class AccountsManager extends RedisPubSubAdapter implemen } private Collection buildPniKeyWriteItems( - final UUID enabledDevicesIdentifier, final UUID phoneNumberIdentifier, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys) { + final Map pniSignedPreKeys, + final Map pniPqLastResortPreKeys) { final List keyWriteItems = new ArrayList<>(); - if (pniSignedPreKeys != null) { - pniSignedPreKeys.forEach((deviceId, signedPreKey) -> - keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey))); - } + pniSignedPreKeys.forEach((deviceId, signedPreKey) -> + keyWriteItems.add(keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey))); - if (pniPqLastResortPreKeys != null) { - keysManager.getPqEnabledDevices(enabledDevicesIdentifier) - .thenAccept(deviceIds -> deviceIds.stream() - .filter(pniPqLastResortPreKeys::containsKey) - .map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, - deviceId, - pniPqLastResortPreKeys.get(deviceId))) - .forEach(keyWriteItems::add)) - .join(); - } + pniPqLastResortPreKeys.forEach((deviceId, lastResortKey) -> + keyWriteItems.add(keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, deviceId, lastResortKey))); return keyWriteItems; } private void setPniKeys(final Account account, - @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniRegistrationIds) { - - if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) { - return; - } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) { - throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null"); - } + final IdentityKey pniIdentityKey, + final Map pniRegistrationIds) { account.getDevices() .forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId()))); @@ -782,22 +761,15 @@ public class AccountsManager extends RedisPubSubAdapter implemen } private void validateDevices(final Account account, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, - @Nullable final Map pniRegistrationIds) throws MismatchedDevicesException { - if (pniSignedPreKeys == null && pniRegistrationIds == null) { - return; - } else if (pniSignedPreKeys == null || pniRegistrationIds == null) { - throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null"); - } + final Map pniSignedPreKeys, + final Map pniPqLastResortPreKeys, + final Map pniRegistrationIds) throws MismatchedDevicesException { // Check that all including primary ID are in signed pre-keys validateCompleteDeviceList(account, pniSignedPreKeys.keySet()); // Check that all including primary ID are in Pq pre-keys - if (pniPqLastResortPreKeys != null) { - validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet()); - } + validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet()); // Check that all devices are accounted for in the map of new PNI registration IDs validateCompleteDeviceList(account, pniRegistrationIds.keySet()); @@ -816,8 +788,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen extraDeviceIds.removeAll(accountDeviceIds); if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { - throw new MismatchedDevicesException( - new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet())); + throw new MismatchedDevicesException(new MismatchedDevices(missingDeviceIds, extraDeviceIds, Set.of())); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 17449a509..7fb942c52 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -42,13 +42,14 @@ public class ChangeNumberManager { this.clock = clock; } - public Account changeNumber(final Account account, final String number, - @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map deviceSignedPreKeys, - @Nullable final Map devicePqLastResortPreKeys, - @Nullable final List deviceMessages, - @Nullable final Map pniRegistrationIds, - @Nullable final String senderUserAgent) + public Account changeNumber(final Account account, + final String number, + final IdentityKey pniIdentityKey, + final Map deviceSignedPreKeys, + final Map devicePqLastResortPreKeys, + final List deviceMessages, + final Map pniRegistrationIds, + final String senderUserAgent) throws InterruptedException, MismatchedDevicesException, MessageTooLargeException { if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) || diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index a6eed2238..40ecc6fab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.storage; -import com.google.common.annotations.VisibleForTesting; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -114,10 +113,6 @@ public class KeysManager { return ecSignedPreKeys.find(identifier, deviceId); } - public CompletableFuture> getPqEnabledDevices(final UUID identifier) { - return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture(); - } - public CompletableFuture getEcCount(final UUID identifier, final byte deviceId) { return ecPreKeys.getCount(identifier, deviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index 2b67cbbea..c51133ead 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -14,14 +14,12 @@ import java.util.concurrent.CompletableFuture; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.AttributeValues; -import reactor.core.publisher.Flux; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.Put; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; /** @@ -116,20 +114,6 @@ public abstract class RepeatedUseSignedPreKeyStore> { return findFuture; } - public Flux getDeviceIdsWithKeys(final UUID identifier) { - return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#uuid = :uuid") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) - .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(identifier))) - .projectionExpression(KEY_DEVICE_ID) - .consistentRead(true) - .build()) - .items()) - .map(item -> Byte.parseByte(item.get(KEY_DEVICE_ID).n())); - } - protected static Map getPrimaryKey(final UUID identifier, final byte deviceId) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index a119bb91e..8d9293bb5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -72,6 +72,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -92,6 +94,7 @@ import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; @@ -100,7 +103,8 @@ class AccountControllerV2Test { private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds(); - private static final IdentityKey IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + private static final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey()); private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("US"), @@ -185,9 +189,11 @@ class AccountControllerV2Test { .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity( - new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), + new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY, Collections.emptyList(), - Collections.emptyMap(), null, Collections.emptyMap()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, 17)), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), @@ -207,10 +213,11 @@ class AccountControllerV2Test { .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity( - new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, - new IdentityKey(Curve.generateKeyPair().getPublicKey()), + new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, IDENTITY_KEY, Collections.emptyList(), - Collections.emptyMap(), null, Collections.emptyMap()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, 17)), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), @@ -291,9 +298,11 @@ class AccountControllerV2Test { .thenReturn(CompletableFuture.completedFuture( Optional.of(new RegistrationServiceSession(new byte[16], NEW_NUMBER, true, null, null, null, SESSION_EXPIRATION_SECONDS)))); - final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(encodeSessionId("session"), - null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), - Collections.emptyList(), Collections.emptyMap(), null, Map.of((byte) 1, pniRegistrationId)); + final ChangeNumberRequest changeNumberRequest = new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY, + Collections.emptyList(), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, pniRegistrationId)); try (final Response response = resources.getJerseyTest() .target("/v2/accounts/number") @@ -503,9 +512,11 @@ class AccountControllerV2Test { .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity( - new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), + new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", IDENTITY_KEY, Collections.emptyList(), - Collections.emptyMap(), null, Collections.emptyMap()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR)), + Map.of(Device.PRIMARY_ID, 17)), MediaType.APPLICATION_JSON_TYPE))) { assertEquals(413, response.getStatus()); @@ -519,17 +530,17 @@ class AccountControllerV2Test { return requestJson("", recoveryPassword, newNumber, 123); } - /** - * Valid request JSON with the given pniRegistrationId - */ - private static String requestJsonRegistrationIds(final Integer pniRegistrationId) { - return requestJson("", new byte[0], "+18005551234", pniRegistrationId); - } - /** * Valid request JSON with the give session ID and recovery password */ - private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber, final Integer pniRegistrationId) { + private static String requestJson(final String sessionId, + final byte[] recoveryPassword, + final String newNumber, + final Integer pniRegistrationId) { + + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR); + final KEMSignedPreKey pniLastResortPreKey = KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR); + return String.format(""" { "sessionId": "%s", @@ -538,10 +549,17 @@ class AccountControllerV2Test { "reglock": "1234", "pniIdentityKey": "%s", "deviceMessages": [], - "devicePniSignedPrekeys": {}, + "devicePniSignedPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}}, + "devicePniPqLastResortPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}}, "pniRegistrationIds": {"1": %d} } - """, encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()), pniRegistrationId); + """, encodeSessionId(sessionId), + encodeRecoveryPassword(recoveryPassword), + newNumber, + Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()), + pniSignedPreKey.keyId(), Base64.getEncoder().encodeToString(pniSignedPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniSignedPreKey.signature()), + pniLastResortPreKey.keyId(), Base64.getEncoder().encodeToString(pniLastResortPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniLastResortPreKey.signature()), + pniRegistrationId); } /** @@ -698,15 +716,21 @@ class AccountControllerV2Test { * Valid request JSON for a {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest} */ private static String requestJson() { + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR); + final KEMSignedPreKey pniLastResortPreKey = KeysHelper.signedKEMPreKey(2, IDENTITY_KEY_PAIR); + return String.format(""" { "pniIdentityKey": "%s", "deviceMessages": [], "devicePniSignedPrekeys": {}, - "devicePniSignedPqPrekeys": {}, - "pniRegistrationIds": {} + "devicePniSignedPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}}, + "devicePniPqLastResortPrekeys": {"1": {"keyId": %d, "publicKey": "%s", "signature": "%s"}}, + "pniRegistrationIds": {"1": 17} } - """, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize())); + """, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()), + pniSignedPreKey.keyId(), Base64.getEncoder().encodeToString(pniSignedPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniSignedPreKey.signature()), + pniLastResortPreKey.keyId(), Base64.getEncoder().encodeToString(pniLastResortPreKey.serializedPublicKey()), Base64.getEncoder().encodeToString(pniLastResortPreKey.signature())); } /** diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 20206f18c..87e047f7d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -173,7 +174,14 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - accountsManager.changeNumber(account, secondNumber, null, null, null, null); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + accountsManager.changeNumber(account, + secondNumber, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -193,6 +201,7 @@ class AccountsManagerChangeNumberIntegrationTest { final int rotatedPniRegistrationId = 17; final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair(); final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair); + final KEMSignedPreKey rotatedKemSignedPreKey = KeysHelper.signedKEMPreKey(2L, rotatedPniIdentityKeyPair); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, Set.of()); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes); @@ -204,9 +213,10 @@ class AccountsManagerChangeNumberIntegrationTest { final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey()); final Map preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); + final Map kemSignedPreKeys = Map.of(Device.PRIMARY_ID, rotatedKemSignedPreKey); final Map registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); - final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); + final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, kemSignedPreKeys, registrationIds); final UUID secondPni = updatedAccount.getPhoneNumberIdentifier(); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -240,9 +250,24 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - account = accountsManager.changeNumber(account, secondNumber, null, null, null, null); + final ECKeyPair originalIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair secondIdentityKeyPair = Curve.generateKeyPair(); + + account = accountsManager.changeNumber(account, + secondNumber, + new IdentityKey(secondIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, secondIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); + final UUID secondPni = account.getPhoneNumberIdentifier(); - accountsManager.changeNumber(account, originalNumber, null, null, null, null); + + accountsManager.changeNumber(account, + originalNumber, + new IdentityKey(originalIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(3, originalIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, originalIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 2)); assertTrue(accountsManager.getByE164(originalNumber).isPresent()); assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow()); @@ -266,11 +291,20 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); + final ECKeyPair originalIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair secondIdentityKeyPair = Curve.generateKeyPair(); + final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); final UUID existingAccountUuid = existingAccount.getUuid(); - accountsManager.changeNumber(account, secondNumber, null, null, null, null); + accountsManager.changeNumber(account, + secondNumber, + new IdentityKey(secondIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, secondIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, secondIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); + final UUID secondPni = accountsManager.getByE164(secondNumber).get().getPhoneNumberIdentifier(); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -285,7 +319,12 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); - accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null); + accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), + originalNumber, + new IdentityKey(originalIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, originalIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, originalIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber); @@ -305,8 +344,15 @@ class AccountsManagerChangeNumberIntegrationTest { final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); final UUID existingAccountUuid = existingAccount.getUuid(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + final Account changedNumberAccount = accountsManager.changeNumber(account, + secondNumber, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); - final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber); @@ -317,7 +363,14 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); - final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null, null); + final ECKeyPair reRegisteredPniIdentityKeyPair = Curve.generateKeyPair(); + + final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, + secondNumber, + new IdentityKey(reRegisteredPniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, reRegisteredPniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, reRegisteredPniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); assertEquals(Optional.of(originalUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 88de8f8f5..097d2b940 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -1048,9 +1048,18 @@ class AccountsManagerTest { final String targetNumber = "+14153333333"; final UUID uuid = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); + final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(1, pniIdentityKeyPair); + final KEMSignedPreKey kemLastResortPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair); + + Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + account = accountsManager.changeNumber(account, + targetNumber, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, ecSignedPreKey), + Map.of(Device.PRIMARY_ID, kemLastResortPreKey), + Map.of(Device.PRIMARY_ID, 1)); assertEquals(targetNumber, account.getNumber()); @@ -1058,18 +1067,26 @@ class AccountsManagerTest { verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(targetNumber)); + verify(keysManager).buildWriteItemForEcSignedPreKey(phoneNumberIdentifiersByE164.get(targetNumber), Device.PRIMARY_ID, ecSignedPreKey); + verify(keysManager).buildWriteItemForLastResortKey(phoneNumberIdentifiersByE164.get(targetNumber), Device.PRIMARY_ID, kemLastResortPreKey); } @Test void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDevicesException { final String number = "+14152222222"; + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); phoneNumberIdentifiersByE164.put(number, account.getPhoneNumberIdentifier()); - account = accountsManager.changeNumber(account, number, null, null, null, null); + account = accountsManager.changeNumber(account, + number, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); assertEquals(number, account.getNumber()); - verify(keysManager, never()).deleteSingleUsePreKeys(any()); + verifyNoInteractions(keysManager); } @Test @@ -1077,31 +1094,20 @@ class AccountsManagerTest { final String originalNumber = "+22923456789"; // the canonical form of numbers may change over time, so we use PNIs as stable identifiers final String newNumber = "+2290123456789"; + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); Account account = AccountsHelper.generateTestAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); phoneNumberIdentifiersByE164.put(originalNumber, account.getPhoneNumberIdentifier()); phoneNumberIdentifiersByE164.put(newNumber, account.getPhoneNumberIdentifier()); - account = accountsManager.changeNumber(account, newNumber, null, null, null, null); + account = accountsManager.changeNumber(account, + newNumber, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)), + Map.of(Device.PRIMARY_ID, 1)); assertEquals(originalNumber, account.getNumber()); - verify(keysManager, never()).deleteSingleUsePreKeys(any()); - } - - @Test - void testChangePhoneNumberSameNumberWithPniData() { - final String number = "+14152222222"; - - Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - phoneNumberIdentifiersByE164.put(number, account.getPhoneNumberIdentifier()); - final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - assertThrows(IllegalArgumentException.class, - () -> accountsManager.changeNumber( - account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), - Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of((byte) 1, 101)), - "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); - - verify(accounts, never()).update(any()); verifyNoInteractions(keysManager); } @@ -1113,12 +1119,21 @@ class AccountsManagerTest { final UUID uuid = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); + final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(1, pniIdentityKeyPair); + final KEMSignedPreKey kemLastResoryPreKey = KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair); + + Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, List.of(DevicesHelper.createDevice(Device.PRIMARY_ID)), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); + account = accountsManager.changeNumber(account, + targetNumber, + new IdentityKey(pniIdentityKeyPair.getPublicKey()), + Map.of(Device.PRIMARY_ID, ecSignedPreKey), + Map.of(Device.PRIMARY_ID, kemLastResoryPreKey), + Map.of(Device.PRIMARY_ID, 1)); assertEquals(targetNumber, account.getNumber()); @@ -1129,6 +1144,9 @@ class AccountsManagerTest { verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); verify(keysManager).deleteSingleUsePreKeys(newPni); + verify(keysManager).buildWriteItemsForRemovedDevice(existingAccountUuid, targetPni, Device.PRIMARY_ID); + verify(keysManager).buildWriteItemForEcSignedPreKey(newPni, Device.PRIMARY_ID, ecSignedPreKey); + verify(keysManager).buildWriteItemForLastResortKey(newPni, Device.PRIMARY_ID, kemLastResoryPreKey); verifyNoMoreInteractions(keysManager); } @@ -1141,27 +1159,22 @@ class AccountsManagerTest { final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); final byte deviceId2 = 2; - final byte deviceId3 = 3; final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final Map newSignedKeys = Map.of( Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), - deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair), - deviceId3, KeysHelper.signedECPreKey(3, identityKeyPair)); + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); final Map newSignedPqKeys = Map.of( Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, identityKeyPair), - deviceId2, KeysHelper.signedKEMPreKey(5, identityKeyPair), - deviceId3, KeysHelper.signedKEMPreKey(6, identityKeyPair)); - final Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202, deviceId3, 203); + deviceId2, KeysHelper.signedKEMPreKey(5, identityKeyPair)); + final Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3))); when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); final List devices = List.of( DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), - DevicesHelper.createDevice(deviceId2, 0L, 102), - DevicesHelper.createDisabledDevice(deviceId3, 103)); + DevicesHelper.createDevice(deviceId2, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account updatedAccount = accountsManager.changeNumber( account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); @@ -1175,24 +1188,20 @@ class AccountsManagerTest { verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); verify(keysManager).deleteSingleUsePreKeys(newPni); verify(keysManager).deleteSingleUsePreKeys(originalPni); - verify(keysManager).getPqEnabledDevices(uuid); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any()); - verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId3), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any()); - verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(deviceId3), any()); + verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(deviceId2), any()); verifyNoMoreInteractions(keysManager); } @Test - void testChangePhoneNumberWithMismatchedPqKeys() throws InterruptedException, MismatchedDevicesException { + void testChangePhoneNumberWithMismatchedPqKeys() { final String originalNumber = "+14152222222"; final String targetNumber = "+14153333333"; - final UUID existingAccountUuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID(); - final UUID targetPni = UUID.randomUUID(); final byte deviceId2 = 2; final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final Map newSignedKeys = Map.of( @@ -1202,11 +1211,6 @@ class AccountsManagerTest { Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)); final Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); - final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn( - CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID))); - final List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), DevicesHelper.createDevice(deviceId2, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); @@ -1242,6 +1246,9 @@ class AccountsManagerTest { Map newSignedKeys = Map.of( Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + Map newSignedKemKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedKEMPreKey(2, identityKeyPair)); Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); @@ -1249,10 +1256,9 @@ class AccountsManagerTest { final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); + final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedKemKeys, newRegistrationIds); // non-PNI stuff should not change assertEquals(oldUuid, updatedAccount.getUuid()); @@ -1272,111 +1278,7 @@ class AccountsManagerTest { verify(keysManager).deleteSingleUsePreKeys(oldPni); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); - verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); - } - - @Test - void testPniPqUpdate() throws MismatchedDevicesException { - final String number = "+14152222222"; - final byte deviceId2 = 2; - - List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), - DevicesHelper.createDevice(deviceId2, 0L, 102)); - - Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), - deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), - deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); - - UUID oldUuid = account.getUuid(); - UUID oldPni = account.getPhoneNumberIdentifier(); - - when(keysManager.getPqEnabledDevices(oldPni)).thenReturn( - CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID))); - when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - - final Account updatedAccount = - accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds); - - // non-PNI-keys stuff should not change - assertEquals(oldUuid, updatedAccount.getUuid()); - assertEquals(number, updatedAccount.getNumber()); - assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); - assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); - assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), - updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); - - // PNI keys should - assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); - assertEquals(newRegistrationIds, - updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); - - verify(accounts).updateTransactionallyAsync(any(), any()); - - verify(keysManager).deleteSingleUsePreKeys(oldPni); - verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); - verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); - verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); - verify(keysManager, never()).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any()); - } - - @Test - void testPniNonPqToPqUpdate() throws MismatchedDevicesException { - final String number = "+14152222222"; - final byte deviceId2 = 2; - - List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), - DevicesHelper.createDevice(deviceId2, 0L, 102)); - - Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), - deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), - deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); - - UUID oldUuid = account.getUuid(); - UUID oldPni = account.getPhoneNumberIdentifier(); - - when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of())); - when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - - final Account updatedAccount = - accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds); - - // non-PNI-keys stuff should not change - assertEquals(oldUuid, updatedAccount.getUuid()); - assertEquals(number, updatedAccount.getNumber()); - assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); - assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); - assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), - updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); - - // PNI keys should - assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); - assertEquals(newRegistrationIds, - updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); - - verify(accounts).updateTransactionallyAsync(any(), any()); - - verify(keysManager).deleteSingleUsePreKeys(oldPni); - verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); - verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); - verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); + verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index bdc13f39e..236665c8a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -7,13 +7,10 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -184,8 +181,6 @@ class KeysManagerTest { @Test void testStorePqLastResort() { - assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final byte deviceId2 = 2; @@ -194,35 +189,21 @@ class KeysManagerTest { keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, (byte) 2, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); - assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); - assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId()); + assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId()); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().orElseThrow().keyId()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent()); keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)).join(); keysManager.storePqLastResort(ACCOUNT_UUID, deviceId3, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join(); - assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates"); - assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(), + assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId(), "storing new last-resort keys should overwrite old ones"); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId(), + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().orElseThrow().keyId(), "storing new last-resort keys should leave untouched ones alone"); - assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(), + assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().orElseThrow().keyId(), "storing new last-resort keys should overwrite old ones"); } - @Test - void testGetPqEnabledDevices() { - keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2)).join(); - keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4)).join(); - - assertIterableEquals( - Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)), - Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join())); - } - private static ECPreKey generateTestPreKey(final long keyId) { return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java index bc40205c1..a854c5174 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java @@ -29,13 +29,4 @@ public class DevicesHelper { return device; } - - public static Device createDisabledDevice(final byte deviceId, final int registrationId) { - final Device device = new Device(); - device.setId(deviceId); - device.setUserAgent("OWT"); - device.setRegistrationId(registrationId); - - return device; - } }