Add paged prekey store

This commit is contained in:
Ravi Khadiwala 2025-05-20 10:47:45 -05:00 committed by ravi-signal
parent 6d8701665e
commit 2bb14892af
23 changed files with 1125 additions and 54 deletions

11
pom.xml
View File

@ -77,6 +77,10 @@
<slf4j.version>2.0.17</slf4j.version> <slf4j.version>2.0.17</slf4j.version>
<stripe.version>23.10.0</stripe.version> <stripe.version>23.10.0</stripe.version>
<swagger.version>2.2.27</swagger.version> <swagger.version>2.2.27</swagger.version>
<testcontainers.version>1.21.1</testcontainers.version>
<!-- image to use in tests that run localstack via docker. -->
<localstack.image>localstack/localstack:3.5.0</localstack.image>
<!-- eclipse-temurin:21.0.6_7-jre-jammy (note: always use the multi-arch manifest *LIST* here) --> <!-- eclipse-temurin:21.0.6_7-jre-jammy (note: always use the multi-arch manifest *LIST* here) -->
<docker.image.sha256>02fc89fa8766a9ba221e69225f8d1c10bb91885ddbd3c112448e23488ba40ab6</docker.image.sha256> <docker.image.sha256>02fc89fa8766a9ba221e69225f8d1c10bb91885ddbd3c112448e23488ba40ab6</docker.image.sha256>
@ -311,6 +315,13 @@
<artifactId>logback-access-common</artifactId> <artifactId>logback-access-common</artifactId>
<version>${logback-access.version}</version> <version>${logback-access.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers-bom</artifactId>
<version>${testcontainers.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>

View File

@ -138,6 +138,8 @@ dynamoDbTables:
tableName: Example_EC_Signed_Pre_Keys tableName: Example_EC_Signed_Pre_Keys
pqKeys: pqKeys:
tableName: Example_PQ_Keys tableName: Example_PQ_Keys
pagedPqKeys:
tableName: Example_PQ_Paged_Keys
pqLastResortKeys: pqLastResortKeys:
tableName: Example_PQ_Last_Resort_Keys tableName: Example_PQ_Last_Resort_Keys
messages: messages:
@ -174,6 +176,10 @@ dynamoDbTables:
verificationSessions: verificationSessions:
tableName: Example_VerificationSessions tableName: Example_VerificationSessions
pagedSingleUseKEMPreKeyStore:
bucket: preKeyBucket # S3 Bucket name
region: us-west-2 # AWS region
cacheCluster: # Redis server configuration for cache cluster cacheCluster: # Redis server configuration for cache cluster
configurationUri: redis://redis.example.com:6379/ configurationUri: redis://redis.example.com:6379/

View File

@ -485,6 +485,18 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>com.google.auth</groupId> <groupId>com.google.auth</groupId>
<artifactId>google-auth-library-oauth2-http</artifactId> <artifactId>google-auth-library-oauth2-http</artifactId>
@ -712,6 +724,9 @@
<configuration> <configuration>
<!-- add-opens: work around PATCH not being a supported method on HttpUrlConnection --> <!-- add-opens: work around PATCH not being a supported method on HttpUrlConnection -->
<argLine>-javaagent:${org.mockito:mockito-core:jar} --add-opens=java.base/java.net=ALL-UNNAMED</argLine> <argLine>-javaagent:${org.mockito:mockito-core:jar} --add-opens=java.base/java.net=ALL-UNNAMED</argLine>
<systemPropertyVariables>
<localstackImage>${localstack.image}</localstackImage>
</systemPropertyVariables>
</configuration> </configuration>
</plugin> </plugin>

View File

@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalit
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration; import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PagedSingleUseKEMPreKeyStoreConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.RegistrationServiceClientFactory; import org.whispersystems.textsecuregcm.configuration.RegistrationServiceClientFactory;
import org.whispersystems.textsecuregcm.configuration.RemoteConfigConfiguration; import org.whispersystems.textsecuregcm.configuration.RemoteConfigConfiguration;
@ -257,6 +258,11 @@ public class WhisperServerConfiguration extends Configuration {
@NotNull @NotNull
private OneTimeDonationConfiguration oneTimeDonations; private OneTimeDonationConfiguration oneTimeDonations;
@Valid
@JsonProperty
@NotNull
private PagedSingleUseKEMPreKeyStoreConfiguration pagedSingleUseKEMPreKeyStore;
@Valid @Valid
@NotNull @NotNull
@JsonProperty @JsonProperty
@ -478,6 +484,10 @@ public class WhisperServerConfiguration extends Configuration {
return oneTimeDonations; return oneTimeDonations;
} }
public PagedSingleUseKEMPreKeyStoreConfiguration getPagedSingleUseKEMPreKeyStore() {
return pagedSingleUseKEMPreKeyStore;
}
public ReportMessageConfiguration getReportMessageConfiguration() { public ReportMessageConfiguration getReportMessageConfiguration() {
return reportMessage; return reportMessage;
} }

View File

@ -225,6 +225,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager; import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager;
import org.whispersystems.textsecuregcm.storage.PagedSingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.storage.PersistentTimer;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.Profiles;
@ -235,8 +236,12 @@ import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs; import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.RepeatedUseECSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.RepeatedUseKEMSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb; import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.SingleUseECPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SubscriptionManager; import org.whispersystems.textsecuregcm.storage.SubscriptionManager;
import org.whispersystems.textsecuregcm.storage.Subscriptions; import org.whispersystems.textsecuregcm.storage.Subscriptions;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
@ -425,13 +430,21 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName()); config.getDynamoDbTables().getProfiles().getTableName());
S3AsyncClient asyncKeysS3Client = S3AsyncClient.builder()
.credentialsProvider(awsCredentialsProvider)
.region(Region.of(config.getPagedSingleUseKEMPreKeyStore().region()))
.build();
KeysManager keysManager = new KeysManager( KeysManager keysManager = new KeysManager(
new SingleUseECPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcKeys().getTableName()),
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemKeys().getTableName()),
new PagedSingleUseKEMPreKeyStore(
dynamoDbAsyncClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(), asyncKeysS3Client,
config.getDynamoDbTables().getKemKeys().getTableName(), config.getDynamoDbTables().getPagedKemKeys().getTableName(),
config.getDynamoDbTables().getEcSignedPreKeys().getTableName(), config.getPagedSingleUseKEMPreKeyStore().bucket()),
config.getDynamoDbTables().getKemLastResortKeys().getTableName() new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
); new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, config.getDynamoDbTables().getKemLastResortKeys().getTableName()));
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(), config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(), config.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -60,6 +60,7 @@ public class DynamoDbTables {
private final Table ecSignedPreKeys; private final Table ecSignedPreKeys;
private final Table kemKeys; private final Table kemKeys;
private final Table kemLastResortKeys; private final Table kemLastResortKeys;
private final Table pagedKemKeys;
private final TableWithExpiration messages; private final TableWithExpiration messages;
private final TableWithExpiration onetimeDonations; private final TableWithExpiration onetimeDonations;
private final Table phoneNumberIdentifiers; private final Table phoneNumberIdentifiers;
@ -88,6 +89,7 @@ public class DynamoDbTables {
@JsonProperty("ecSignedPreKeys") final Table ecSignedPreKeys, @JsonProperty("ecSignedPreKeys") final Table ecSignedPreKeys,
@JsonProperty("pqKeys") final Table kemKeys, @JsonProperty("pqKeys") final Table kemKeys,
@JsonProperty("pqLastResortKeys") final Table kemLastResortKeys, @JsonProperty("pqLastResortKeys") final Table kemLastResortKeys,
@JsonProperty("pagedPqKeys") final Table pagedKemKeys,
@JsonProperty("messages") final TableWithExpiration messages, @JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("onetimeDonations") final TableWithExpiration onetimeDonations, @JsonProperty("onetimeDonations") final TableWithExpiration onetimeDonations,
@JsonProperty("phoneNumberIdentifiers") final Table phoneNumberIdentifiers, @JsonProperty("phoneNumberIdentifiers") final Table phoneNumberIdentifiers,
@ -114,6 +116,7 @@ public class DynamoDbTables {
this.ecKeys = ecKeys; this.ecKeys = ecKeys;
this.ecSignedPreKeys = ecSignedPreKeys; this.ecSignedPreKeys = ecSignedPreKeys;
this.kemKeys = kemKeys; this.kemKeys = kemKeys;
this.pagedKemKeys = pagedKemKeys;
this.kemLastResortKeys = kemLastResortKeys; this.kemLastResortKeys = kemLastResortKeys;
this.messages = messages; this.messages = messages;
this.onetimeDonations = onetimeDonations; this.onetimeDonations = onetimeDonations;
@ -202,6 +205,12 @@ public class DynamoDbTables {
return kemKeys; return kemKeys;
} }
@NotNull
@Valid
public Table getPagedKemKeys() {
return pagedKemKeys;
}
@NotNull @NotNull
@Valid @Valid
public Table getKemLastResortKeys() { public Table getKemLastResortKeys() {

View File

@ -0,0 +1,15 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
public record PagedSingleUseKEMPreKeyStoreConfiguration(
@NotBlank String bucket,
@NotBlank String region) {
}

View File

@ -0,0 +1,136 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.nio.ByteBuffer;
import java.util.List;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
class KEMPreKeyPage {
static final byte FORMAT = 1;
// Serialized pages start with a 4 byte magic constant, followed by 3 bytes of 0s and then the format byte
static final int HEADER_MAGIC = 0xC21C6DB8;
static final int HEADER_SIZE = 8;
// Serialize bigendian to produce the serialized page header
private static final long HEADER = ((long) HEADER_MAGIC) << 32L | (long) FORMAT;
// The length of libsignal's serialized KEM public key, which is a single-byte version followed by the public key
private static final int SERIALIZED_PUBKEY_LENGTH = 1569;
private static final int SERIALIZED_SIGNATURE_LENGTH = 64;
private static final int KEY_ID_LENGTH = Long.BYTES;
// The internal prefix byte libsignal uses to indicate a key is of type KEMKeyType.KYBER_1024. Currently, this
// is the only type of key allowed to be written to a prekey page
private static final byte KEM_KEY_TYPE_KYBER_1024 = 0x08;
@VisibleForTesting
static final int SERIALIZED_PREKEY_LENGTH = KEY_ID_LENGTH + SERIALIZED_PUBKEY_LENGTH + SERIALIZED_SIGNATURE_LENGTH;
private KEMPreKeyPage() {}
/**
* Serialize the list of preKeys into a single buffer
*
* @param format the format to serialize as. Currently, the only valid format is {@link KEMPreKeyPage#FORMAT}
* @param preKeys the preKeys to serialize
* @return The serialized buffer and a format to store alongside the buffer
*/
static ByteBuffer serialize(final byte format, final List<KEMSignedPreKey> preKeys) {
if (format != FORMAT) {
throw new IllegalArgumentException("Unknown format: " + format + ", must be " + FORMAT);
}
if (preKeys.isEmpty()) {
throw new IllegalArgumentException("PreKeys cannot be empty");
}
final ByteBuffer buffer = ByteBuffer.allocate(HEADER_SIZE + SERIALIZED_PREKEY_LENGTH * preKeys.size());
buffer.putLong(HEADER);
for (KEMSignedPreKey preKey : preKeys) {
buffer.putLong(preKey.keyId());
final byte[] publicKeyBytes = preKey.serializedPublicKey();
if (publicKeyBytes[0] != KEM_KEY_TYPE_KYBER_1024) {
// 0x08 is libsignal's current KEM key format. If some future version of libsignal supports additional KEM
// keys, we'll have to roll out read support before rolling out write support. Otherwise, we may write keys
// to storage that are not readable by other chat instances.
throw new IllegalArgumentException("Format 1 only supports " + KEM_KEY_TYPE_KYBER_1024 + " public keys");
}
if (publicKeyBytes.length != SERIALIZED_PUBKEY_LENGTH) {
throw new IllegalArgumentException("Unexpected public key length " + publicKeyBytes.length);
}
buffer.put(publicKeyBytes);
if (preKey.signature().length != SERIALIZED_SIGNATURE_LENGTH) {
throw new IllegalArgumentException("prekey signature length must be " + SERIALIZED_SIGNATURE_LENGTH);
}
buffer.put(preKey.signature());
}
buffer.flip();
return buffer;
}
/**
* Deserialize a single {@link KEMSignedPreKey}
*
* @param format The format of the page this buffer is from
* @param buffer The key to deserialize. The position of the buffer should be the start of the key, and the limit of
* the buffer should be the end of the key. After a successful deserialization the position of the
* buffer will be the limit
* @return The deserialized key
* @throws InvalidKeyException
*/
static KEMSignedPreKey deserializeKey(int format, ByteBuffer buffer) throws InvalidKeyException {
if (format != FORMAT) {
throw new IllegalArgumentException("Unknown prekey page format " + format);
}
if (buffer.remaining() != SERIALIZED_PREKEY_LENGTH) {
throw new IllegalArgumentException("PreKeys must be length " + SERIALIZED_PREKEY_LENGTH);
}
final long keyId = buffer.getLong();
final byte[] publicKeyBytes = new byte[SERIALIZED_PUBKEY_LENGTH];
buffer.get(publicKeyBytes);
final KEMPublicKey kemPublicKey = new KEMPublicKey(publicKeyBytes);
final byte[] signature = new byte[SERIALIZED_SIGNATURE_LENGTH];
buffer.get(signature);
return new KEMSignedPreKey(keyId, kemPublicKey, signature);
}
/**
* The location of a specific key within a serialized page
*/
record KeyLocation(int start, int length) {
int getStartInclusive() {
return start;
}
int getEndInclusive() {
return start + length - 1;
}
}
/**
* Get the location of the key at the provided index within a page
*
* @param format The format of the page
* @param index The index of the key to retrieve
* @return An {@link KeyLocation} indicating where within the page the key is
*/
static KeyLocation keyLocation(final int format, final int index) {
if (format != FORMAT) {
throw new IllegalArgumentException("unknown format " + format);
}
final int startOffset = HEADER_SIZE + (index * SERIALIZED_PREKEY_LENGTH);
return new KeyLocation(startOffset, SERIALIZED_PREKEY_LENGTH);
}
}

View File

@ -12,26 +12,27 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class KeysManager { public class KeysManager {
private final SingleUseECPreKeyStore ecPreKeys; private final SingleUseECPreKeyStore ecPreKeys;
private final SingleUseKEMPreKeyStore pqPreKeys; private final SingleUseKEMPreKeyStore pqPreKeys;
private final PagedSingleUseKEMPreKeyStore pagedPqPreKeys;
private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys; private final RepeatedUseECSignedPreKeyStore ecSignedPreKeys;
private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys; private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys;
public KeysManager( public KeysManager(
final DynamoDbAsyncClient dynamoDbAsyncClient, final SingleUseECPreKeyStore ecPreKeys,
final String ecTableName, final SingleUseKEMPreKeyStore pqPreKeys,
final String pqTableName, final PagedSingleUseKEMPreKeyStore pagedPqPreKeys,
final String ecSignedPreKeysTableName, final RepeatedUseECSignedPreKeyStore ecSignedPreKeys,
final String pqLastResortTableName) { final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys) {
this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName); this.ecPreKeys = ecPreKeys;
this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName); this.pqPreKeys = pqPreKeys;
this.ecSignedPreKeys = new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, ecSignedPreKeysTableName); this.pagedPqPreKeys = pagedPqPreKeys;
this.pqLastResortKeys = new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName); this.ecSignedPreKeys = ecSignedPreKeys;
this.pqLastResortKeys = pqLastResortKeys;
} }
public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier, public TransactWriteItem buildWriteItemForEcSignedPreKey(final UUID identifier,

View File

@ -0,0 +1,367 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
/**
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on
* an object store and referenced via dynamodb. Each device may only have a single active page at a time. Crashes or
* errors may leave orphaned pages which are no longer referenced by the database. A background process must
* periodically check for orphaned pages and remove them.
* @see SingleUsePreKeyStore
*/
public class PagedSingleUseKEMPreKeyStore {
private static final Logger log = LoggerFactory.getLogger(PagedSingleUseKEMPreKeyStore.class);
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final S3AsyncClient s3AsyncClient;
private final String tableName;
private final String bucketName;
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.publishPercentileHistogram()
.register(Metrics.globalRegistry);
private final String takeKeyTimerName = name(getClass(), "takeKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID = "D";
static final String ATTR_PAGE_ID = "ID";
static final String ATTR_PAGE_IDX = "I";
static final String ATTR_PAGE_NUM_KEYS = "N";
static final String ATTR_PAGE_FORMAT_VERSION = "F";
public PagedSingleUseKEMPreKeyStore(
final DynamoDbAsyncClient dynamoDbAsyncClient,
final S3AsyncClient s3AsyncClient,
final String tableName,
final String bucketName) {
this.s3AsyncClient = s3AsyncClient;
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
this.bucketName = bucketName;
}
/**
* Stores a batch of single-use pre-keys for a specific device. All previously-stored keys for the device are cleared
* before storing new keys.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param preKeys a collection of single-use pre-keys to store for the target device
* @return a future that completes when all previously-stored keys have been removed and the given collection of
* pre-keys has been stored in its place
*/
public CompletableFuture<Void> store(
final UUID identifier, final byte deviceId, final List<KEMSignedPreKey> preKeys) {
final Timer.Sample sample = Timer.start();
final List<KEMSignedPreKey> sorted = preKeys.stream().sorted(Comparator.comparing(KEMSignedPreKey::keyId)).toList();
final int bundleFormat = KEMPreKeyPage.FORMAT;
final ByteBuffer bundle = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, sorted);
// Write the bundle to S3, then update the database. Delete the S3 object that was in the database before. This can
// leave orphans in S3 if we fail to update after writing to S3, or fail to delete the old page. However, it can
// never leave a broken pointer in the database. To keep this invariant, we must make sure to generate a new
// name for the page any time we were to retry this entire operation.
return writeBundleToS3(identifier, deviceId, bundle)
.thenCompose(pageId -> dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId),
ATTR_PAGE_ID, AttributeValues.fromUUID(pageId),
ATTR_PAGE_IDX, AttributeValues.fromInt(0),
ATTR_PAGE_NUM_KEYS, AttributeValues.fromInt(sorted.size()),
ATTR_PAGE_FORMAT_VERSION, AttributeValues.fromInt(bundleFormat)
))
.returnValues(ReturnValue.ALL_OLD)
.build()))
.thenCompose(response -> {
if (response.hasAttributes()) {
final UUID pageId = AttributeValues.getUUID(response.attributes(), ATTR_PAGE_ID, null);
if (pageId == null) {
log.error("Replaced record: {} with no pageId", response.attributes());
return CompletableFuture.completedFuture(null);
}
return deleteBundleFromS3(identifier, deviceId, pageId);
} else {
return CompletableFuture.completedFuture(null);
}
})
.whenComplete((result, error) -> sample.stop(storeKeyBatchTimer));
}
/**
* Attempts to retrieve a single-use pre-key for a specific device. Keys may only be returned by this method at most
* once; once the key is returned, it is removed from the key store and subsequent calls to this method will never
* return the same key.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are
* available for the target device
*/
public CompletableFuture<Optional<KEMSignedPreKey>> take(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.updateExpression("SET #index = #index + :one")
.conditionExpression("#id = :id AND #index < #numkeys")
.expressionAttributeNames(Map.of(
"#id", KEY_ACCOUNT_UUID,
"#index", ATTR_PAGE_IDX,
"#numkeys", ATTR_PAGE_NUM_KEYS))
.expressionAttributeValues(Map.of(
":one", AttributeValues.n(1),
":id", AttributeValues.fromUUID(identifier)))
.returnValues(ReturnValue.ALL_OLD)
.build())
.thenCompose(updateItemResponse -> {
if (!updateItemResponse.hasAttributes()) {
throw new IllegalStateException("update succeeded but did not return an item");
}
final int index = AttributeValues.getInt(updateItemResponse.attributes(), ATTR_PAGE_IDX, -1);
final UUID pageId = AttributeValues.getUUID(updateItemResponse.attributes(), ATTR_PAGE_ID, null);
final int format = AttributeValues.getInt(updateItemResponse.attributes(), ATTR_PAGE_FORMAT_VERSION, -1);
if (index < 0 || format < 0 || pageId == null) {
throw new CompletionException(
new IOException("unexpected page descriptor " + updateItemResponse.attributes()));
}
return readPreKeyAtIndexFromS3(identifier, deviceId, pageId, format, index).thenApply(Optional::of);
})
// If this check fails, it means that the item did not exist, or its index was already at the last key. Either
// way, there are no keys left so we return empty
.exceptionally(ExceptionUtils.exceptionallyHandler(
ConditionalCheckFailedException.class,
e -> Optional.empty()))
.whenComplete((maybeKey, throwable) ->
sample.stop(Metrics.timer(
takeKeyTimerName,
KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))));
}
/**
* Returns the number of single-use pre-keys available for a given device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields the approximate number of single-use pre-keys currently available for the target
* device
*/
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.consistentRead(true)
.projectionExpression("#total, #index")
.expressionAttributeNames(Map.of(
"#total", ATTR_PAGE_NUM_KEYS,
"#index", ATTR_PAGE_IDX))
.build())
.thenApply(getResponse -> {
if (!getResponse.hasItem()) {
return 0;
}
final int numKeys = AttributeValues.getInt(getResponse.item(), ATTR_PAGE_NUM_KEYS, -1);
final int index = AttributeValues.getInt(getResponse.item(), ATTR_PAGE_IDX, -1);
if (numKeys < 0 || index < 0 || index > numKeys) {
log.error("unexpected index/length in page descriptor: {}", getResponse.item());
return 0;
}
return numKeys - index;
})
.whenComplete((keyCount, throwable) -> {
sample.stop(getKeyCountTimer);
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
/**
* Removes all single-use pre-keys for all devices associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to remove single-use pre-keys
* @return a future that completes when all single-use pre-keys have been removed for all devices associated with the
* given account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return deleteItems(identifier, Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.projectionExpression("#uuid,#deviceid,#pageid")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#deviceid", KEY_DEVICE_ID,
"#pageid", ATTR_PAGE_ID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(identifier)))
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForAccountTimer));
}
/**
* Removes all single-use pre-keys for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that completes when all single-use pre-keys have been removed for the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
.consistentRead(true)
.projectionExpression("#uuid,#deviceid,#pageid")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#deviceid", KEY_DEVICE_ID,
"#pageid", ATTR_PAGE_ID))
.build())
.thenCompose(getItemResponse -> deleteItems(identifier, getItemResponse.hasItem()
? Flux.just(getItemResponse.item())
: Flux.empty()))
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
private CompletableFuture<Void> deleteItems(final UUID identifier,
final Flux<Map<String, AttributeValue>> items) {
return items
.flatMap(item -> {
final UUID aci = AttributeValues.getUUID(item, KEY_ACCOUNT_UUID, null);
final byte deviceId = (byte) AttributeValues.getInt(item, KEY_DEVICE_ID, -1);
final UUID pageId = AttributeValues.getUUID(item, ATTR_PAGE_ID, null);
if (aci == null || deviceId < 0 || pageId == null) {
log.error("can't delete page from unexpected page descriptor {}", item);
}
return Mono.fromFuture(deleteBundleFromS3(aci, deviceId, pageId))
.thenReturn(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)));
})
.flatMap(itemToDelete -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(itemToDelete)
.build())))
.then()
.toFuture()
.thenRun(Util.NOOP);
}
private static String s3Key(final UUID identifier, final byte deviceId, final UUID pageId) {
return String.format("%s/%s/%s", identifier, deviceId, pageId);
}
private CompletableFuture<UUID> writeBundleToS3(final UUID identifier, final byte deviceId,
final ByteBuffer bundle) {
final UUID pageId = UUID.randomUUID();
return s3AsyncClient.putObject(PutObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId)).build(),
AsyncRequestBody.fromByteBuffer(bundle))
.thenApply(ignoredResponse -> pageId);
}
private CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
return s3AsyncClient.deleteObject(DeleteObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId))
.build())
.thenRun(Util.NOOP);
}
private CompletableFuture<KEMSignedPreKey> readPreKeyAtIndexFromS3(
final UUID identifier, final byte deviceId, final UUID pageId, final int format, final int index) {
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(format, index);
return s3AsyncClient.getObject(GetObjectRequest.builder()
.bucket(bucketName)
.key(s3Key(identifier, deviceId, pageId))
// An RFC9110 range header, inclusive on both ends
// https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2
.range("bytes=%s-%s".formatted(keyLocation.getStartInclusive(), keyLocation.getEndInclusive()))
.build(), AsyncResponseTransformer.toBytes())
.thenApply(bytes -> {
final ByteBuffer serialized = bytes.asByteBuffer();
if (serialized.remaining() != keyLocation.length()) {
log.error("Unexpected ranged read response, requested {} got {} for offset {} in page {}",
keyLocation.length(), serialized.remaining(), keyLocation, s3Key(identifier, deviceId, pageId));
throw new CompletionException(new IOException("Invalid response to ranged read"));
}
try {
return KEMPreKeyPage.deserializeKey(format, bytes.asByteBuffer());
} catch (InvalidKeyException e) {
throw new CompletionException(new IOException(e));
}
});
}
}

View File

@ -19,7 +19,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> { public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(SingleUseECPreKeyStore.class, "parseByteArray"); private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(SingleUseECPreKeyStore.class, "parseByteArray");
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { public SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName); super(dynamoDbAsyncClient, tableName);
} }

View File

@ -16,7 +16,7 @@ import java.util.UUID;
public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKey> { public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKey> {
protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { public SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName); super(dynamoDbAsyncClient, tableName);
} }

View File

@ -57,13 +57,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PagedSingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.RepeatedUseECSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.RepeatedUseKEMSignedPreKeyStore;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb; import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.SingleUseECPreKeyStore;
import org.whispersystems.textsecuregcm.storage.SingleUseKEMPreKeyStore;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
@ -204,13 +209,20 @@ record CommandDependencies(
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName()); configuration.getDynamoDbTables().getProfiles().getTableName());
S3AsyncClient asyncKeysS3Client = S3AsyncClient.builder()
.credentialsProvider(awsCredentialsProvider)
.region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region()))
.build();
KeysManager keys = new KeysManager( KeysManager keys = new KeysManager(
dynamoDbAsyncClient, new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()),
configuration.getDynamoDbTables().getEcKeys().getTableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()),
configuration.getDynamoDbTables().getKemKeys().getTableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, asyncKeysS3Client,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName(), configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName() configuration.getPagedSingleUseKEMPreKeyStore().bucket()),
); new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()));
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(), configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -50,6 +50,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class AccountCreationDeletionIntegrationTest { public class AccountCreationDeletionIntegrationTest {
@ -71,6 +72,9 @@ public class AccountCreationDeletionIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ScheduledExecutorService executor; private ScheduledExecutorService executor;
@ -90,13 +94,18 @@ public class AccountCreationDeletionIntegrationTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName() DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
); S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
class AccountsManagerChangeNumberIntegrationTest { class AccountsManagerChangeNumberIntegrationTest {
@ -65,6 +66,9 @@ class AccountsManagerChangeNumberIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private KeysManager keysManager; private KeysManager keysManager;
private DisconnectionRequestManager disconnectionRequestManager; private DisconnectionRequestManager disconnectionRequestManager;
private ScheduledExecutorService executor; private ScheduledExecutorService executor;
@ -81,13 +85,18 @@ class AccountsManagerChangeNumberIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
Tables.EC_KEYS.tableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
Tables.PQ_KEYS.tableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), S3_EXTENSION.getS3Client(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName() DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
); S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -47,6 +47,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@ -78,6 +79,9 @@ class AccountsManagerUsernameIntegrationTest {
@RegisterExtension @RegisterExtension
static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private AccountsManager accountsManager; private AccountsManager accountsManager;
private Accounts accounts; private Accounts accounts;
@ -94,13 +98,18 @@ class AccountsManagerUsernameIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
final KeysManager keysManager = new KeysManager( final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
Tables.EC_KEYS.tableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
Tables.PQ_KEYS.tableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), S3_EXTENSION.getS3Client(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName() DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
); S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()));
accounts = Mockito.spy(new Accounts( accounts = Mockito.spy(new Accounts(
Clock.systemUTC(), Clock.systemUTC(),

View File

@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class AddRemoveDeviceIntegrationTest { public class AddRemoveDeviceIntegrationTest {
@ -70,6 +71,9 @@ public class AddRemoveDeviceIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private ExecutorService accountLockExecutor; private ExecutorService accountLockExecutor;
private ScheduledExecutorService messagePollExecutor; private ScheduledExecutorService messagePollExecutor;
@ -89,13 +93,18 @@ public class AddRemoveDeviceIntegrationTest {
clock = TestClock.pinned(Instant.now()); clock = TestClock.pinned(Instant.now());
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), new SingleUseECPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName() DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
); S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()));
final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName());

View File

@ -143,6 +143,20 @@ public final class DynamoDbExtensionSchema {
.build()), .build()),
List.of(), List.of()), List.of(), List.of()),
PAGED_PQ_KEYS("paged_pq_keys_test",
PagedSingleUseKEMPreKeyStore.KEY_ACCOUNT_UUID,
PagedSingleUseKEMPreKeyStore.KEY_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(PagedSingleUseKEMPreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(PagedSingleUseKEMPreKeyStore.KEY_DEVICE_ID)
.attributeType(ScalarAttributeType.N)
.build()),
List.of(), List.of()),
PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test", PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test",
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME, PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME,
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID, PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID,

View File

@ -0,0 +1,102 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class KEMPreKeyPageTest {
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@Test
void serializeSinglePreKey() {
final ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(generatePreKey(5)));
final int actualMagic = page.getInt();
assertEquals(KEMPreKeyPage.HEADER_MAGIC, actualMagic);
final int version = page.getInt();
assertEquals(version, 1);
assertEquals(KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH, page.remaining());
}
@Test
void emptyPreKeys() {
assertThrows(IllegalArgumentException.class, () -> KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, Collections.emptyList()));
}
@Test
void roundTripSingleton() throws InvalidKeyException {
final KEMSignedPreKey preKey = generatePreKey(5);
final ByteBuffer buffer = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(preKey));
final long serializedLength = buffer.remaining();
assertEquals(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH, serializedLength);
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(1, 0);
assertEquals(KEMPreKeyPage.HEADER_SIZE, keyLocation.getStartInclusive());
assertEquals(serializedLength, KEMPreKeyPage.HEADER_SIZE + keyLocation.length());
buffer.position(keyLocation.getStartInclusive());
final KEMSignedPreKey deserializedPreKey = KEMPreKeyPage.deserializeKey(1, buffer);
assertEquals(5L, deserializedPreKey.keyId());
assertEquals(preKey, deserializedPreKey);
}
@Test
void roundTripMultiple() throws InvalidKeyException {
final List<KEMSignedPreKey> keys = Arrays.asList(generatePreKey(1), generatePreKey(2), generatePreKey(5));
final ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, keys);
assertEquals(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH * 3, page.remaining());
for (int i = 0; i < keys.size(); i++) {
final KEMPreKeyPage.KeyLocation keyLocation = KEMPreKeyPage.keyLocation(1, i);
assertEquals(
KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH * i,
keyLocation.getStartInclusive());
final ByteBuffer buf = page.slice(keyLocation.getStartInclusive(), keyLocation.length());
final KEMSignedPreKey actual = KEMPreKeyPage.deserializeKey(1, buf);
assertEquals(keys.get(i), actual);
}
}
@Test
void wrongFormat() {
assertThrows(IllegalArgumentException.class, () ->
KEMPreKeyPage.deserializeKey(2,
ByteBuffer.allocate(KEMPreKeyPage.HEADER_SIZE + KEMPreKeyPage.SERIALIZED_PREKEY_LENGTH)));
}
@Test
void wrongSize() {
assertThrows(IllegalArgumentException.class, () -> KEMPreKeyPage.deserializeKey(1, ByteBuffer.allocate(100)));
}
@Test
void negativeKeyId() throws InvalidKeyException {
final KEMSignedPreKey preKey = generatePreKey(-1);
ByteBuffer page = KEMPreKeyPage.serialize(KEMPreKeyPage.FORMAT, List.of(preKey));
page.position(KEMPreKeyPage.HEADER_SIZE);
KEMSignedPreKey deserializedPreKey = KEMPreKeyPage.deserializeKey(1, page);
assertEquals(-1L, deserializedPreKey.keyId());
}
private static KEMSignedPreKey generatePreKey(long keyId) {
return KeysHelper.signedKEMPreKey((int) keyId, IDENTITY_KEY_PAIR);
}
}

View File

@ -22,6 +22,7 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
class KeysManagerTest { class KeysManagerTest {
@ -31,6 +32,9 @@ class KeysManagerTest {
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension("testbucket");
private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final byte DEVICE_ID = 1; private static final byte DEVICE_ID = 1;
@ -38,13 +42,16 @@ class KeysManagerTest {
@BeforeEach @BeforeEach
void setup() { void setup() {
final DynamoDbAsyncClient dynamoDbAsyncClient = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient();
keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), new SingleUseECPreKeyStore(dynamoDbAsyncClient, Tables.EC_KEYS.tableName()),
Tables.EC_KEYS.tableName(), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, Tables.PQ_KEYS.tableName()),
Tables.PQ_KEYS.tableName(), new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), S3_EXTENSION.getS3Client(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName() DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
); S3_EXTENSION.getBucketName()),
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName()),
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName()));
} }
@Test @Test

View File

@ -0,0 +1,218 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.S3Object;
class PagedSingleUseKEMPreKeyStoreTest {
private static final int KEY_COUNT = 100;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private static final String BUCKET_NAME = "testbucket";
private PagedSingleUseKEMPreKeyStore keyStore;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS);
@RegisterExtension
static final S3LocalStackExtension S3_EXTENSION = new S3LocalStackExtension(BUCKET_NAME);
@BeforeEach
void setUp() {
keyStore = new PagedSingleUseKEMPreKeyStore(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
S3_EXTENSION.getS3Client(),
DynamoDbExtensionSchema.Tables.PAGED_PQ_KEYS.tableName(),
BUCKET_NAME);
}
@Test
void storeTake() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(Optional.empty(), keyStore.take(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
.sorted(Comparator.comparing(preKey -> preKey.keyId()))
.toList();
assertEquals(Optional.of(sortedPreKeys.get(0)), keyStore.take(accountIdentifier, deviceId).join());
assertEquals(Optional.of(sortedPreKeys.get(1)), keyStore.take(accountIdentifier, deviceId).join());
}
@Test
void storeTwice() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final List<KEMSignedPreKey> preKeys1 = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys1).join();
List<String> oldPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList());
assertEquals(1, oldPages.size());
final List<KEMSignedPreKey> preKeys2 = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys2).join();
List<String> newPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList());
assertEquals(1, newPages.size());
assertNotEquals(oldPages.getFirst(), newPages.getFirst());
assertEquals(
preKeys2.stream().sorted(Comparator.comparing(preKey -> preKey.keyId())).toList(),
IntStream.range(0, preKeys2.size())
.mapToObj(i -> keyStore.take(accountIdentifier, deviceId).join())
.map(Optional::orElseThrow)
.toList());
assertTrue(keyStore.take(accountIdentifier, deviceId).join().isEmpty());
}
@Test
void takeAll() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
.sorted(Comparator.comparing(preKey -> preKey.keyId()))
.toList();
for (int i = 0; i < KEY_COUNT; i++) {
assertEquals(Optional.of(sortedPreKeys.get(i)), keyStore.take(accountIdentifier, deviceId).join());
}
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertTrue(keyStore.take(accountIdentifier, deviceId).join().isEmpty());
}
@Test
void getCount() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
assertEquals(KEY_COUNT, keyStore.getCount(accountIdentifier, deviceId).join());
for (int i = 0; i < KEY_COUNT; i++) {
keyStore.take(accountIdentifier, deviceId).join();
assertEquals(KEY_COUNT - (i + 1), keyStore.getCount(accountIdentifier, deviceId).join());
}
}
@Test
void deleteSingleDevice() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier, deviceId).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
keyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier, deviceId).join());
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(KEY_COUNT, keyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
final List<S3Object> pages = listPages(accountIdentifier);
assertEquals(1, pages.size());
assertTrue(pages.get(0).key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1)));
}
@Test
void deleteAllDevices() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier).join());
final List<KEMSignedPreKey> preKeys = generateRandomPreKeys();
keyStore.store(accountIdentifier, deviceId, preKeys).join();
keyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> keyStore.delete(accountIdentifier).join());
assertEquals(0, keyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, keyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
assertEquals(0, listPages(accountIdentifier).size());
}
private List<S3Object> listPages(final UUID identifier) {
return Flux.from(S3_EXTENSION.getS3Client().listObjectsV2Paginator(ListObjectsV2Request.builder()
.bucket(BUCKET_NAME)
.prefix(identifier.toString())
.build()))
.concatMap(response -> Flux.fromIterable(response.contents()))
.collectList()
.block();
}
private List<KEMSignedPreKey> generateRandomPreKeys() {
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);
while (keyIds.size() < KEY_COUNT) {
keyIds.add(Math.abs(ThreadLocalRandom.current().nextInt()));
}
return keyIds.stream()
.map(keyId -> KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR))
.toList();
}
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2021-2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
import java.util.Objects;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.DeleteBucketRequest;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
@Testcontainers
public class S3LocalStackExtension implements BeforeEachCallback, AfterEachCallback, BeforeAllCallback,
AfterAllCallback {
private final static DockerImageName LOCAL_STACK_IMAGE =
DockerImageName.parse(Objects.requireNonNull(
System.getProperty("localstackImage"),
"Local stack image not found; must provide localstackImage system property"));
private static LocalStackContainer LOCAL_STACK = new LocalStackContainer(LOCAL_STACK_IMAGE).withServices(S3);
private final String bucketName;
private S3AsyncClient s3Client;
public S3LocalStackExtension(final String bucketName) {
this.bucketName = bucketName;
}
@Override
public void afterEach(ExtensionContext context) {
Flux.from(s3Client.listObjectsV2Paginator(ListObjectsV2Request.builder()
.bucket(bucketName)
.build())
.contents())
.flatMap(obj -> Mono.fromFuture(() -> s3Client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucketName)
.key(obj.key())
.build())), 100)
.then()
.block();
s3Client.deleteBucket(DeleteBucketRequest.builder().bucket(bucketName).build()).join();
}
@Override
public void beforeEach(ExtensionContext context) throws Exception {
s3Client.createBucket(CreateBucketRequest.builder().bucket(bucketName).build()).join();
}
public S3AsyncClient getS3Client() {
return s3Client;
}
@Override
public void afterAll(final ExtensionContext context) throws Exception {
s3Client.close();
LOCAL_STACK.close();
}
@Override
public void beforeAll(final ExtensionContext context) throws Exception {
LOCAL_STACK.start();
s3Client = S3AsyncClient.builder()
.endpointOverride(LOCAL_STACK.getEndpoint())
.credentialsProvider(StaticCredentialsProvider
.create(AwsBasicCredentials.create(LOCAL_STACK.getAccessKey(), LOCAL_STACK.getSecretKey())))
.region(Region.of(LOCAL_STACK.getRegion()))
.build();
}
public String getBucketName() {
return bucketName;
}
}

View File

@ -135,6 +135,8 @@ dynamoDbTables:
tableName: repeated_use_signed_ec_pre_keys_test tableName: repeated_use_signed_ec_pre_keys_test
pqKeys: pqKeys:
tableName: pq_keys_test tableName: pq_keys_test
pagedPqKeys:
tableName: paged_pq_keys_test
pqLastResortKeys: pqLastResortKeys:
tableName: repeated_use_signed_kem_pre_keys_test tableName: repeated_use_signed_kem_pre_keys_test
messages: messages:
@ -171,6 +173,10 @@ dynamoDbTables:
verificationSessions: verificationSessions:
tableName: verification_sessions_test tableName: verification_sessions_test
pagedSingleUseKEMPreKeyStore:
bucket: preKeyBucket # S3 Bucket name
region: us-west-2 # AWS region
cacheCluster: # Redis server configuration for cache cluster cacheCluster: # Redis server configuration for cache cluster
type: local type: local