diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java index 536884179..6984d94ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -24,11 +24,16 @@ public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { } @Override - protected Map getItemFromPreKey(final UUID identifier, final byte deviceId, final ECPreKey preKey) { + protected Map getItemFromPreKey(final UUID identifier, + final byte deviceId, + final ECPreKey preKey, + final int remainingKeys) { + return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()), - ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey())); + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey()), + ATTR_REMAINING_KEYS, AttributeValues.fromInt(remainingKeys)); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java index b373d0e57..de485705c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java @@ -21,12 +21,17 @@ public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore getItemFromPreKey(final UUID identifier, final byte deviceId, final KEMSignedPreKey signedPreKey) { + protected Map getItemFromPreKey(final UUID identifier, + final byte deviceId, + final KEMSignedPreKey signedPreKey, + final int remainingKeys) { + return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()), ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()), - ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature())); + ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature()), + ATTR_REMAINING_KEYS, AttributeValues.fromInt(remainingKeys)); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java index 305f80147..1251b39f2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -8,16 +8,19 @@ package org.whispersystems.textsecuregcm.storage; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE; +import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.nio.ByteBuffer; import java.time.Duration; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -49,9 +52,10 @@ public abstract class SingleUsePreKeyStore> { private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; + private final String getKeyCountTimerName = name(getClass(), "getCount"); + private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey")); private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch")); - private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount")); private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); @@ -74,6 +78,7 @@ public abstract class SingleUsePreKeyStore> { static final String KEY_DEVICE_ID_KEY_ID = "DK"; static final String ATTR_PUBLIC_KEY = "P"; static final String ATTR_SIGNATURE = "S"; + static final String ATTR_REMAINING_KEYS = "R"; protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { this.dynamoDbAsyncClient = dynamoDbAsyncClient; @@ -97,18 +102,22 @@ public abstract class SingleUsePreKeyStore> { return Mono.fromFuture(() -> delete(identifier, deviceId)) .thenMany( Flux.fromIterable(preKeys) - .flatMap(preKey -> Mono.fromFuture(() -> store(identifier, deviceId, preKey)), DYNAMO_DB_MAX_BATCH_SIZE)) + .sort(Comparator.comparing(preKey -> preKey.keyId())) + .zipWith(Flux.range(0, preKeys.size()).map(i -> preKeys.size() - i)) + .flatMap(preKeyAndRemainingCount -> Mono.fromFuture(() -> + store(identifier, deviceId, preKeyAndRemainingCount.getT1(), preKeyAndRemainingCount.getT2())), + DYNAMO_DB_MAX_BATCH_SIZE)) .then() .toFuture() .thenRun(() -> sample.stop(storeKeyBatchTimer)); } - private CompletableFuture store(final UUID identifier, final byte deviceId, final K preKey) { + private CompletableFuture store(final UUID identifier, final byte deviceId, final K preKey, final int remainingKeys) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.putItem(PutItemRequest.builder() .tableName(tableName) - .item(getItemFromPreKey(identifier, deviceId, preKey)) + .item(getItemFromPreKey(identifier, deviceId, preKey, remainingKeys)) .build()) .thenRun(() -> sample.stop(storeKeyTimer)); } @@ -172,6 +181,56 @@ public abstract class SingleUsePreKeyStore> { public CompletableFuture getCount(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); + final AtomicBoolean countFromPeek = new AtomicBoolean(true); + + return peekCount(identifier, deviceId) + .thenCompose(maybeCount -> maybeCount + .map(CompletableFuture::completedFuture) + // Older key sets may not have a pre-calculated pre-key count; take the less efficient approach of counting + // items instead + .orElseGet(() -> { + countFromPeek.set(false); + return scanCount(identifier, deviceId); + })) + .whenComplete((keyCount, throwable) -> { + sample.stop(Metrics.timer(getKeyCountTimerName, "method", countFromPeek.get() ? "peek" : "scan")); + + if (throwable == null && keyCount != null) { + availableKeyCountDistributionSummary.record(keyCount); + } + }); + } + + @VisibleForTesting + CompletableFuture> peekCount(final UUID identifier, final byte deviceId) { + return dynamoDbAsyncClient.query(QueryRequest.builder() + .tableName(tableName) + .consistentRead(false) + .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .expressionAttributeValues(Map.of( + ":uuid", getPartitionKey(identifier), + ":sortprefix", getSortKeyPrefix(deviceId))) + .projectionExpression(ATTR_REMAINING_KEYS) + .limit(1) + .build()) + .thenApply(response -> { + if (response.count() > 0) { + final Map item = response.items().getFirst(); + + if (item.containsKey(ATTR_REMAINING_KEYS)) { + return Optional.of(Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n())); + } else { + return Optional.empty(); + } + } else { + return Optional.of(0); + } + }); + } + + @VisibleForTesting + CompletableFuture scanCount(final UUID identifier, final byte deviceId) { // Getting an accurate count from DynamoDB can be very confusing. See: // // - https://github.com/aws/aws-sdk-java/issues/693 @@ -189,14 +248,7 @@ public abstract class SingleUsePreKeyStore> { .build())) .map(QueryResponse::count) .reduce(0, Integer::sum) - .toFuture() - .whenComplete((keyCount, throwable) -> { - sample.stop(getKeyCountTimer); - - if (throwable == null && keyCount != null) { - availableKeyCountDistributionSummary.record(keyCount); - } - }); + .toFuture(); } /** @@ -280,8 +332,10 @@ public abstract class SingleUsePreKeyStore> { return AttributeValues.fromByteBuffer(byteBuffer.flip()); } - protected abstract Map getItemFromPreKey(final UUID identifier, final byte deviceId, - final K preKey); + protected abstract Map getItemFromPreKey(final UUID identifier, + final byte deviceId, + final K preKey, + final int remainingKeys); protected abstract K getPreKeyFromItem(final Map item); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java index 52044e4f9..06eabf0ae 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java @@ -9,6 +9,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.entities.ECPreKey; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable; +import java.util.Map; class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { @@ -32,4 +38,24 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { protected ECPreKey generatePreKey(final long keyId) { return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); } + + @Override + protected void clearKeyCountAttributes() { + final ScanIterable scanIterable = DYNAMO_DB_EXTENSION.getDynamoDbClient().scanPaginator(ScanRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()) + .build()); + + for (final ScanResponse response : scanIterable) { + for (final Map item : response.items()) { + + DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()) + .key(Map.of( + SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID), + SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID))) + .updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS) + .build()); + } + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java index e21685f3c..f4d0336f0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java @@ -11,6 +11,12 @@ import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable; +import java.util.Map; class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest { @@ -36,4 +42,24 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest item : response.items()) { + + DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()) + .key(Map.of( + SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID), + SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID))) + .updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS) + .build()); + } + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java index 08693ce48..d28dd4f5b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -9,10 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.PreKey; abstract class SingleUsePreKeyStoreTest> { @@ -23,6 +29,8 @@ abstract class SingleUsePreKeyStoreTest> { protected abstract K generatePreKey(final long keyId); + protected abstract void clearKeyCountAttributes(); + @Test void storeTake() { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); @@ -32,20 +40,22 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join()); - final List preKeys = new ArrayList<>(KEY_COUNT); + final List sortedPreKeys; + { + final List preKeys = generateRandomPreKeys(); + assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join()); - for (int i = 0; i < KEY_COUNT; i++) { - preKeys.add(generatePreKey(i)); + sortedPreKeys = new ArrayList<>(preKeys); + sortedPreKeys.sort(Comparator.comparing(preKey -> preKey.keyId())); } - assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join()); - - assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join()); - assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); + assertEquals(Optional.of(sortedPreKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join()); + assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); } - @Test - void getCount() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void getCount(final boolean hasKeyCountAttribute) { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); @@ -53,15 +63,72 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); - final List preKeys = new ArrayList<>(KEY_COUNT); - - for (int i = 0; i < KEY_COUNT; i++) { - preKeys.add(generatePreKey(i)); - } + final List preKeys = generateRandomPreKeys(); preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + if (!hasKeyCountAttribute) { + clearKeyCountAttributes(); + } + assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join()); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeyStore.take(accountIdentifier, deviceId).join(); + assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join()); + } + } + + @Test + void peekCount() { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = 1; + + assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join()); + + final List preKeys = generateRandomPreKeys(); + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + + assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join()); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeyStore.take(accountIdentifier, deviceId).join(); + assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join()); + } + + preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join(); + clearKeyCountAttributes(); + + assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void scanCount(final boolean hasKeyCountAttribute) { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = 1; + + assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join()); + + final List preKeys = generateRandomPreKeys(); + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + + if (!hasKeyCountAttribute) { + clearKeyCountAttributes(); + } + + assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join()); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeyStore.take(accountIdentifier, deviceId).join(); + assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join()); + } } @Test @@ -74,11 +141,7 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join()); - final List preKeys = new ArrayList<>(KEY_COUNT); - - for (int i = 0; i < KEY_COUNT; i++) { - preKeys.add(generatePreKey(i)); - } + final List preKeys = generateRandomPreKeys(); preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join(); @@ -99,11 +162,7 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join()); - final List preKeys = new ArrayList<>(KEY_COUNT); - - for (int i = 0; i < KEY_COUNT; i++) { - preKeys.add(generatePreKey(i)); - } + final List preKeys = generateRandomPreKeys(); preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join(); @@ -113,4 +172,16 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join()); } + + private List generateRandomPreKeys() { + final Set keyIds = new HashSet<>(KEY_COUNT); + + while (keyIds.size() < KEY_COUNT) { + keyIds.add(Math.abs(ThreadLocalRandom.current().nextInt())); + } + + return keyIds.stream() + .map(this::generatePreKey) + .toList(); + } }