Always read from new and old PQ prekey stores, add experiment to start writing to new prekey store

This commit is contained in:
ravi-signal 2025-07-09 09:17:17 -05:00 committed by GitHub
parent 80c11e7eda
commit c9f21d5970
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 184 additions and 36 deletions

View File

@ -368,6 +368,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MetricsUtil.configureRegistries(config, environment, dynamicConfigurationManager); MetricsUtil.configureRegistries(config, environment, dynamicConfigurationManager);
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager);
if (config.getServerFactory() instanceof DefaultServerFactory defaultServerFactory) { if (config.getServerFactory() instanceof DefaultServerFactory defaultServerFactory) {
defaultServerFactory.getApplicationConnectors() defaultServerFactory.getApplicationConnectors()
.forEach(connectorFactory -> { .forEach(connectorFactory -> {
@ -444,7 +446,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPagedKemKeys().getTableName(), config.getDynamoDbTables().getPagedKemKeys().getTableName(),
config.getPagedSingleUseKEMPreKeyStore().bucket()), config.getPagedSingleUseKEMPreKeyStore().bucket()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcSignedPreKeys().getTableName()), new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName())); new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(), config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(), config.getDynamoDbTables().getMessages().getExpiration(),
@ -604,8 +607,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ExternalServiceCredentialsGenerator svr2CredentialsGenerator = SecureValueRecovery2Controller.credentialsGenerator( ExternalServiceCredentialsGenerator svr2CredentialsGenerator = SecureValueRecovery2Controller.credentialsGenerator(
config.getSvr2Configuration()); config.getSvr2Configuration());
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager =
new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords); new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords);
UsernameHashZkProofVerifier usernameHashZkProofVerifier = new UsernameHashZkProofVerifier(); UsernameHashZkProofVerifier usernameHashZkProofVerifier = new UsernameHashZkProofVerifier();

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import java.time.Instant; import io.micrometer.core.instrument.Metrics;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -13,6 +13,8 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
@ -23,18 +25,25 @@ public class KeysManager {
private final PagedSingleUseKEMPreKeyStore pagedPqPreKeys; private final PagedSingleUseKEMPreKeyStore pagedPqPreKeys;
private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys; private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys;
private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys; private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
public static String PAGED_KEYS_EXPERIMENT_NAME = "pagedPreKeys";
private static final String TAKE_PQ_NAME = MetricsUtil.name(KeysManager.class, "takePq");
public KeysManager( public KeysManager(
final SingleUseECPreKeyStore ecPreKeys, final SingleUseECPreKeyStore ecPreKeys,
final SingleUseKEMPreKeyStore pqPreKeys, final SingleUseKEMPreKeyStore pqPreKeys,
final PagedSingleUseKEMPreKeyStore pagedPqPreKeys, final PagedSingleUseKEMPreKeyStore pagedPqPreKeys,
final RepeatedUseECSignedPreKeyStore ecSignedPreKeys, final RepeatedUseECSignedPreKeyStore ecSignedPreKeys,
final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys) { final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.ecPreKeys = ecPreKeys; this.ecPreKeys = ecPreKeys;
this.pqPreKeys = pqPreKeys; this.pqPreKeys = pqPreKeys;
this.pagedPqPreKeys = pagedPqPreKeys; this.pagedPqPreKeys = pagedPqPreKeys;
this.ecSignedPreKeys = ecSignedPreKeys; this.ecSignedPreKeys = ecSignedPreKeys;
this.pqLastResortKeys = pqLastResortKeys; this.pqLastResortKeys = pqLastResortKeys;
this.experimentEnrollmentManager = experimentEnrollmentManager;
} }
public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier, public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier,
@ -79,22 +88,31 @@ public class KeysManager {
); );
} }
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId, final ECSignedPreKey ecSignedPreKey) { public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final byte deviceId,
final ECSignedPreKey ecSignedPreKey) {
return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey); return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey);
} }
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId, final KEMSignedPreKey lastResortKey) { public CompletableFuture<Void> storePqLastResort(final UUID identifier, final byte deviceId,
final KEMSignedPreKey lastResortKey) {
return pqLastResortKeys.store(identifier, deviceId, lastResortKey); return pqLastResortKeys.store(identifier, deviceId, lastResortKey);
} }
public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final byte deviceId, public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<ECPreKey> preKeys) { final List<ECPreKey> preKeys) {
return ecPreKeys.store(identifier, deviceId, preKeys); return ecPreKeys.store(identifier, deviceId, preKeys);
} }
public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final byte deviceId, public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<KEMSignedPreKey> preKeys) { final List<KEMSignedPreKey> preKeys) {
return pqPreKeys.store(identifier, deviceId, preKeys); final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
final CompletableFuture<Void> deleteOtherKeys = enrolledInPagedKeys
? pqPreKeys.delete(identifier, deviceId)
: pagedPqPreKeys.delete(identifier, deviceId);
return deleteOtherKeys.thenCompose(ignored -> enrolledInPagedKeys
? pagedPqPreKeys.store(identifier, deviceId, preKeys)
: pqPreKeys.store(identifier, deviceId, preKeys));
} }
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) { public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
@ -102,10 +120,36 @@ public class KeysManager {
} }
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) { public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
return pqPreKeys.take(identifier, deviceId) final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME);
return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(ignored -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> tagTakePQ(pqPreKeys.take(identifier, deviceId), PQSource.ROW, enrolledInPagedKeys)))
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))); .orElseGet(() -> tagTakePQ(pqLastResortKeys.find(identifier, deviceId), PQSource.LAST_RESORT, enrolledInPagedKeys)));
}
private enum PQSource {
PAGE,
ROW,
LAST_RESORT
}
private CompletableFuture<Optional<KEMSignedPreKey>> tagTakePQ(CompletableFuture<Optional<KEMSignedPreKey>> prekey, final PQSource source, final boolean enrolledInPagedKeys) {
return prekey.thenApply(maybeSingleUsePreKey -> {
final Optional<String> maybeSourceTag = maybeSingleUsePreKey
// If we found a PK, use this source tag
.map(ignore -> source.name())
// If we didn't and this is our last resort, we didn't find a PK
.or(() -> source == PQSource.LAST_RESORT ? Optional.of("absent") : Optional.empty());
maybeSourceTag.ifPresent(sourceTag -> {
Metrics.counter(TAKE_PQ_NAME,
"source", sourceTag,
"enrolled", Boolean.toString(enrolledInPagedKeys))
.increment();
});
return maybeSingleUsePreKey;
});
} }
public CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) { public CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
@ -121,20 +165,24 @@ public class KeysManager {
} }
public CompletableFuture<Integer> getPqCount(final UUID identifier, final byte deviceId) { public CompletableFuture<Integer> getPqCount(final UUID identifier, final byte deviceId) {
return pqPreKeys.getCount(identifier, deviceId); return pagedPqPreKeys.getCount(identifier, deviceId).thenCompose(count -> count == 0
? pqPreKeys.getCount(identifier, deviceId)
: CompletableFuture.completedFuture(count));
} }
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID identifier) { public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID identifier) {
return CompletableFuture.allOf( return CompletableFuture.allOf(
ecPreKeys.delete(identifier), ecPreKeys.delete(identifier),
pqPreKeys.delete(identifier) pqPreKeys.delete(identifier),
pagedPqPreKeys.delete(identifier)
); );
} }
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) { public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
return CompletableFuture.allOf( return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId), ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId) pqPreKeys.delete(accountUuid, deviceId),
pagedPqPreKeys.delete(accountUuid, deviceId)
); );
} }

View File

@ -22,7 +22,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -44,7 +43,12 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.*; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Object;
/** /**
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on * @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on

View File

@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
@ -122,6 +123,9 @@ record CommandDependencies(
new DynamicConfigurationManager<>( new DynamicConfigurationManager<>(
configuration.getDynamicConfig().build(awsCredentialsProvider, dynamicConfigurationExecutor), DynamicConfiguration.class); configuration.getDynamicConfig().build(awsCredentialsProvider, dynamicConfigurationExecutor), DynamicConfiguration.class);
dynamicConfigurationManager.start(); dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager =
new ExperimentEnrollmentManager(dynamicConfigurationManager);
final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder(); final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder();
FaultTolerantRedisClusterClient cacheCluster = configuration.getCacheClusterConfiguration() FaultTolerantRedisClusterClient cacheCluster = configuration.getCacheClusterConfiguration()
@ -224,7 +228,8 @@ record CommandDependencies(
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()), configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName())); configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()),
experimentEnrollmentManager);
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(), configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -66,6 +67,7 @@ public class AccountCreationDeletionIntegrationTest {
DynamoDbExtensionSchema.Tables.USERNAMES, DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS, DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS, DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@ -105,7 +107,8 @@ public class AccountCreationDeletionIntegrationTest {
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -60,6 +61,7 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.USERNAMES, Tables.USERNAMES,
Tables.EC_KEYS, Tables.EC_KEYS,
Tables.PQ_KEYS, Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@ -96,7 +98,8 @@ class AccountsManagerChangeNumberIntegrationTest {
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -73,6 +73,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.DELETED_ACCOUNTS, Tables.DELETED_ACCOUNTS,
Tables.EC_KEYS, Tables.EC_KEYS,
Tables.PQ_KEYS, Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);

View File

@ -38,6 +38,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -73,6 +74,7 @@ class AccountsManagerUsernameIntegrationTest {
Tables.PNI_ASSIGNMENTS, Tables.PNI_ASSIGNMENTS,
Tables.EC_KEYS, Tables.EC_KEYS,
Tables.PQ_KEYS, Tables.PQ_KEYS,
Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@ -109,7 +111,8 @@ class AccountsManagerUsernameIntegrationTest {
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
accounts = Mockito.spy(new Accounts( accounts = Mockito.spy(new Accounts(
Clock.systemUTC(), Clock.systemUTC(),

View File

@ -36,6 +36,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
@ -62,6 +63,7 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.USERNAMES, DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS, DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS, DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@ -104,7 +106,8 @@ public class AddRemoveDeviceIntegrationTest {
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
mock(ExperimentEnrollmentManager.class));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -8,18 +8,26 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
@ -27,10 +35,15 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
class KeysManagerTest { class KeysManagerTest {
private KeysManager keysManager; private KeysManager keysManager;
private ExperimentEnrollmentManager experimentEnrollmentManager;
private SingleUseKEMPreKeyStore singleUseKEMPreKeyStore;
private PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore;
@RegisterExtension @RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PAGED_PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension @RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket"); static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
@ -43,15 +56,20 @@ class KeysManagerTest {
@BeforeEach @BeforeEach
void setup() { void setup() {
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(); final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
singleUseKEMPreKeyStore = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName());
pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName());
keysManager = new KeysManager( keysManager = new KeysManager(
new SingleUseECPreKeyStore(dynamoDbAsyncClient, Tables.EC_KEYS.tableName()), new SingleUseECPreKeyStore(dynamoDbAsyncClient, Tables.EC_KEYS.tableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName()), singleUseKEMPreKeyStore,
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, pagedSingleUseKEMPreKeyStore,
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()), new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName())); new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()),
experimentEnrollmentManager);
} }
@Test @Test
@ -67,18 +85,58 @@ class KeysManagerTest {
"Repeatedly storing same key should have no effect"); "Repeatedly storing same key should have no effect");
} }
@Test @ParameterizedTest
void storeKemOneTimePreKeys() { @ValueSource(booleans = {true, false})
void storeKemOneTimePreKeysClearsOld(boolean inPagedExperiment) {
final List<KEMSignedPreKey> oldPreKeys = List.of(generateTestKEMSignedPreKey(1));
// Leave a key in the 'other' key store
(inPagedExperiment
? singleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys)
: pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, oldPreKeys))
.join();
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
final List<KEMSignedPreKey> newPreKeys = List.of(generateTestKEMSignedPreKey(2));
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, newPreKeys).join();
final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0;
final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount;
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
final KEMSignedPreKey key = keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow();
assertEquals(2, key.keyId());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void storeKemOneTimePreKeys(boolean inPagedExperiment) {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero"); "Initial pre-key count for an account should be zero");
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); .thenReturn(inPagedExperiment);
final int expectedPagedKeyCount = inPagedExperiment ? 1 : 0;
final int expectedUnpagedKeyCount = 1 - expectedPagedKeyCount;
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedPagedKeyCount, pagedSingleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(expectedUnpagedKeyCount, singleUseKEMPreKeyStore.getCount(ACCOUNT_UUID, DEVICE_ID).join());
} }
@Test @Test
void storeEcSignedPreKeys() { void storeEcSignedPreKeys() {
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty());
@ -128,9 +186,24 @@ class KeysManagerTest {
} }
@Test @Test
void testDeleteSingleUsePreKeysByAccount() { void takeWithExistingExperimentalKey() {
// Put a key in the new store, even though we're not in the experiment. This simulates a take when operating
// in mixed mode on experiment rollout
pagedSingleUseKEMPreKeyStore.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join().orElseThrow().keyId());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testDeleteSingleUsePreKeysByAccount(final boolean inPagedExperiment) {
int keyId = 1; int keyId = 1;
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
@ -155,10 +228,14 @@ class KeysManagerTest {
} }
} }
@Test @ParameterizedTest
void testDeleteSingleUsePreKeysByAccountAndDevice() { @ValueSource(booleans = {true, false})
void testDeleteSingleUsePreKeysByAccountAndDevice(final boolean inPagedExperiment) {
int keyId = 1; int keyId = 1;
when(experimentEnrollmentManager.isEnrolled(ACCOUNT_UUID, KeysManager.PAGED_KEYS_EXPERIMENT_NAME))
.thenReturn(inPagedExperiment);
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();