From c9f21d5970d19bc241be5437a8d0d6be409b3e3a Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Wed, 9 Jul 2025 09:17:17 -0500 Subject: [PATCH] Always read from new and old PQ prekey stores, add experiment to start writing to new prekey store --- .../textsecuregcm/WhisperServerService.java | 7 +- .../textsecuregcm/storage/KeysManager.java | 72 ++++++++++-- .../storage/PagedSingleUseKEMPreKeyStore.java | 8 +- .../workers/CommandDependencies.java | 7 +- ...ccountCreationDeletionIntegrationTest.java | 5 +- ...ntsManagerChangeNumberIntegrationTest.java | 5 +- ...ConcurrentModificationIntegrationTest.java | 1 + ...ccountsManagerUsernameIntegrationTest.java | 5 +- .../AddRemoveDeviceIntegrationTest.java | 5 +- .../storage/KeysManagerTest.java | 105 +++++++++++++++--- 10 files changed, 184 insertions(+), 36 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index a887fff57..0023cc8b6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -368,6 +368,8 @@ public class WhisperServerService extends Application { @@ -444,7 +446,8 @@ public class WhisperServerService extends Application storeEcSignedPreKeys(final UUID identifier, final byte deviceId, final ECSignedPreKey ecSignedPreKey) { + public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final byte deviceId, + final ECSignedPreKey ecSignedPreKey) { return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey); } - public CompletableFuture storePqLastResort(final UUID identifier, final byte deviceId, final KEMSignedPreKey lastResortKey) { + public CompletableFuture storePqLastResort(final UUID identifier, final byte deviceId, + final KEMSignedPreKey lastResortKey) { return pqLastResortKeys.store(identifier, deviceId, lastResortKey); } public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final byte deviceId, - final List preKeys) { + final List preKeys) { return ecPreKeys.store(identifier, deviceId, preKeys); } public CompletableFuture storeKemOneTimePreKeys(final UUID identifier, final byte deviceId, - final List preKeys) { - return pqPreKeys.store(identifier, deviceId, preKeys); + final List preKeys) { + final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME); + final CompletableFuture deleteOtherKeys = enrolledInPagedKeys + ? pqPreKeys.delete(identifier, deviceId) + : pagedPqPreKeys.delete(identifier, deviceId); + return deleteOtherKeys.thenCompose(ignored -> enrolledInPagedKeys + ? pagedPqPreKeys.store(identifier, deviceId, preKeys) + : pqPreKeys.store(identifier, deviceId, preKeys)); + } public CompletableFuture> takeEC(final UUID identifier, final byte deviceId) { @@ -102,10 +120,36 @@ public class KeysManager { } public CompletableFuture> takePQ(final UUID identifier, final byte deviceId) { - return pqPreKeys.take(identifier, deviceId) + final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME); + return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys) + .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey + .map(ignored -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) + .orElseGet(() -> tagTakePQ(pqPreKeys.take(identifier, deviceId), PQSource.ROW, enrolledInPagedKeys))) .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) - .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))); + .orElseGet(() -> tagTakePQ(pqLastResortKeys.find(identifier, deviceId), PQSource.LAST_RESORT, enrolledInPagedKeys))); + } + + private enum PQSource { + PAGE, + ROW, + LAST_RESORT + } + private CompletableFuture> tagTakePQ(CompletableFuture> prekey, final PQSource source, final boolean enrolledInPagedKeys) { + return prekey.thenApply(maybeSingleUsePreKey -> { + final Optional maybeSourceTag = maybeSingleUsePreKey + // If we found a PK, use this source tag + .map(ignore -> source.name()) + // If we didn't and this is our last resort, we didn't find a PK + .or(() -> source == PQSource.LAST_RESORT ? Optional.of("absent") : Optional.empty()); + maybeSourceTag.ifPresent(sourceTag -> { + Metrics.counter(TAKE_PQ_NAME, + "source", sourceTag, + "enrolled", Boolean.toString(enrolledInPagedKeys)) + .increment(); + }); + return maybeSingleUsePreKey; + }); } public CompletableFuture> getLastResort(final UUID identifier, final byte deviceId) { @@ -121,20 +165,24 @@ public class KeysManager { } public CompletableFuture getPqCount(final UUID identifier, final byte deviceId) { - return pqPreKeys.getCount(identifier, deviceId); + return pagedPqPreKeys.getCount(identifier, deviceId).thenCompose(count -> count == 0 + ? pqPreKeys.getCount(identifier, deviceId) + : CompletableFuture.completedFuture(count)); } public CompletableFuture deleteSingleUsePreKeys(final UUID identifier) { return CompletableFuture.allOf( ecPreKeys.delete(identifier), - pqPreKeys.delete(identifier) + pqPreKeys.delete(identifier), + pagedPqPreKeys.delete(identifier) ); } public CompletableFuture deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) { return CompletableFuture.allOf( ecPreKeys.delete(accountUuid, deviceId), - pqPreKeys.delete(accountUuid, deviceId) + pqPreKeys.delete(accountUuid, deviceId), + pagedPqPreKeys.delete(accountUuid, deviceId) ); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java index d15582993..9d0b6a9be 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java @@ -22,7 +22,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.function.Function; import java.util.stream.Collectors; - import org.signal.libsignal.protocol.InvalidKeyException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +43,12 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.*; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Object; /** * @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 2a015f887..8a6055ff3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; @@ -122,6 +123,9 @@ record CommandDependencies( new DynamicConfigurationManager<>( configuration.getDynamicConfig().build(awsCredentialsProvider, dynamicConfigurationExecutor), DynamicConfiguration.class); dynamicConfigurationManager.start(); + ExperimentEnrollmentManager experimentEnrollmentManager = + new ExperimentEnrollmentManager(dynamicConfigurationManager); + final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder(); FaultTolerantRedisClusterClient cacheCluster = configuration.getCacheClusterConfiguration() @@ -224,7 +228,8 @@ record CommandDependencies( new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, - configuration.getDynamoDbTables().getKemLastResortKeys().getTableName())); + configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()), + experimentEnrollmentManager); MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getExpiration(), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index c47411be8..7df1be6bf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -66,6 +67,7 @@ public class AccountCreationDeletionIntegrationTest { DynamoDbExtensionSchema.Tables.USERNAMES, DynamoDbExtensionSchema.Tables.EC_KEYS, DynamoDbExtensionSchema.Tables.PQ_KEYS, + DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @@ -105,7 +107,8 @@ public class AccountCreationDeletionIntegrationTest { new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, - DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()), + mock(ExperimentEnrollmentManager.class)); final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index d1869812b..944a266ba 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -60,6 +61,7 @@ class AccountsManagerChangeNumberIntegrationTest { Tables.USERNAMES, Tables.EC_KEYS, Tables.PQ_KEYS, + Tables.PAGED_PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @@ -96,7 +98,8 @@ class AccountsManagerChangeNumberIntegrationTest { new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, - DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()), + mock(ExperimentEnrollmentManager.class)); final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index b04d93397..2d33d394a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -73,6 +73,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { Tables.DELETED_ACCOUNTS, Tables.EC_KEYS, Tables.PQ_KEYS, + Tables.PAGED_PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index add5d5b5c..c4c1cd84e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -38,6 +38,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -73,6 +74,7 @@ class AccountsManagerUsernameIntegrationTest { Tables.PNI_ASSIGNMENTS, Tables.EC_KEYS, Tables.PQ_KEYS, + Tables.PAGED_PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @@ -109,7 +111,8 @@ class AccountsManagerUsernameIntegrationTest { new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, - DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()), + mock(ExperimentEnrollmentManager.class)); accounts = Mockito.spy(new Accounts( Clock.systemUTC(), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index c5f2d4f31..69970697e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -36,6 +36,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.DeviceInfo; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; @@ -62,6 +63,7 @@ public class AddRemoveDeviceIntegrationTest { DynamoDbExtensionSchema.Tables.USERNAMES, DynamoDbExtensionSchema.Tables.EC_KEYS, DynamoDbExtensionSchema.Tables.PQ_KEYS, + DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @@ -104,7 +106,8 @@ public class AddRemoveDeviceIntegrationTest { new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, - DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()), + mock(ExperimentEnrollmentManager.class)); final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index 0a98cb72f..0d1765341 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -8,18 +8,26 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; import java.util.List; import java.util.Optional; import java.util.UUID; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; @@ -27,10 +35,15 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; class KeysManagerTest { private KeysManager keysManager; + private ExperimentEnrollmentManager experimentEnrollmentManager; + + private SingleUseKEMPreKeyStore singleUseKEMPreKeyStore; + private PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore; @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( - Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); + Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PAGED_PQ_KEYS, + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @RegisterExtension static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket"); @@ -43,15 +56,20 @@ class KeysManagerTest { @BeforeEach void setup() { final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(); + experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + singleUseKEMPreKeyStore = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName()); + pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, + S3_EXTENSION.getS3Client(), + DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(), + S3_EXTENSION.getBucketName()); + keysManager = new KeysManager( new SingleUseECPreKeyStore(dynamoDbAsyncClient, Tables.EC_KEYS.tableName()), - new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName()), - new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, - S3_EXTENSION.getS3Client(), - DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(), - S3_EXTENSION.getBucketName()), + singleUseKEMPreKeyStore, + pagedSingleUseKEMPreKeyStore, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), - new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); + new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()), + experimentEnrollmentManager); } @Test @@ -67,18 +85,58 @@ class KeysManagerTest { "Repeatedly storing same key should have no effect"); } - @Test - void storeKemOneTimePreKeys() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void storeKemOneTimePreKeysClearsOld(boolean inPagedExperiment) { + final List oldPreKeys = List.of(generateTestKEMSignedPreKey(1)); + + // Leave a key in the 'other' key store + (inPagedExperiment + ? singleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys) + : pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys)) + .join(); + + when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME)) + .thenReturn(inPagedExperiment); + + + final List newPreKeys = List.of(generateTestKEMSignedPreKey(2)); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, newPreKeys).join(); + + final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0; + final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount; + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); + + final KEMSignedPreKey key = keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow(); + assertEquals(2, key.keyId()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void storeKemOneTimePreKeys(boolean inPagedExperiment) { assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Initial pre-key count for an account should be zero"); - keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME)) + .thenReturn(inPagedExperiment); + + final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0; + final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount; keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); + + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join()); } + @Test void storeEcSignedPreKeys() { assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty()); @@ -128,9 +186,24 @@ class KeysManagerTest { } @Test - void testDeleteSingleUsePreKeysByAccount() { + void takeWithExistingExperimentalKey() { + // Put a key in the new store, even though we're not in the experiment. This simulates a take when operating + // in mixed mode on experiment rollout + pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); + + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(1, keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testDeleteSingleUsePreKeysByAccount(final boolean inPagedExperiment) { int keyId = 1; + when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME)) + .thenReturn(inPagedExperiment); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); @@ -155,10 +228,14 @@ class KeysManagerTest { } } - @Test - void testDeleteSingleUsePreKeysByAccountAndDevice() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testDeleteSingleUsePreKeysByAccountAndDevice(final boolean inPagedExperiment) { int keyId = 1; + when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME)) + .thenReturn(inPagedExperiment); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();