From 796dce3cd31c124468d6540c94e96ac7f109b552 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 1 Apr 2024 13:29:17 -0400 Subject: [PATCH] Always use the "peek" strategy for counting one-time pre-keys --- .../storage/SingleUsePreKeyStore.java | 71 +++++-------------- .../storage/SingleUsePreKeyStoreTest.java | 60 +--------------- 2 files changed, 21 insertions(+), 110 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java index 0aa2af10a..bbf428dfa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE; -import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; @@ -20,7 +20,6 @@ import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -33,9 +32,7 @@ import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.QueryRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryResponse; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; -import software.amazon.awssdk.services.dynamodb.model.Select; /** * A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key @@ -52,13 +49,14 @@ public abstract class SingleUsePreKeyStore> { private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; - private final String getKeyCountTimerName = name(getClass(), "getCount"); - + private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount")); private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey")); private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch")); private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); + private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable")); + final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary .builder(name(getClass(), "keysConsideredForTake")) .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) @@ -181,28 +179,6 @@ public abstract class SingleUsePreKeyStore> { public CompletableFuture getCount(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); - final AtomicBoolean countFromPeek = new AtomicBoolean(true); - - return peekCount(identifier, deviceId) - .thenCompose(maybeCount -> maybeCount - .map(CompletableFuture::completedFuture) - // Older key sets may not have a pre-calculated pre-key count; take the less efficient approach of counting - // items instead - .orElseGet(() -> { - countFromPeek.set(false); - return scanCount(identifier, deviceId); - })) - .whenComplete((keyCount, throwable) -> { - sample.stop(Metrics.timer(getKeyCountTimerName, "method", countFromPeek.get() ? "peek" : "scan")); - - if (throwable == null && keyCount != null) { - availableKeyCountDistributionSummary.record(keyCount); - } - }); - } - - @VisibleForTesting - CompletableFuture> peekCount(final UUID identifier, final byte deviceId) { return dynamoDbAsyncClient.query(QueryRequest.builder() .tableName(tableName) .consistentRead(false) @@ -219,38 +195,27 @@ public abstract class SingleUsePreKeyStore> { final Map item = response.items().getFirst(); if (item.containsKey(ATTR_REMAINING_KEYS)) { - return Optional.of(Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n())); + return Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n()); } else { - return Optional.empty(); + // Some legacy keys sets may not have pre-counted keys; in that case, we'll tell the owners of those key + // sets that they have none remaining, prompting an upload of a fresh set that we'll pre-count. This has + // no effect on consumers of keys, which will still be able to take keys if any are actually present. + noKeyCountAvailableCounter.increment(); + return 0; } } else { - return Optional.of(0); + return 0; + } + }) + .whenComplete((keyCount, throwable) -> { + sample.stop(getKeyCountTimer); + + if (throwable == null && keyCount != null) { + availableKeyCountDistributionSummary.record(keyCount); } }); } - @VisibleForTesting - CompletableFuture scanCount(final UUID identifier, final byte deviceId) { - // Getting an accurate count from DynamoDB can be very confusing. See: - // - // - https://github.com/aws/aws-sdk-java/issues/693 - // - https://github.com/aws/aws-sdk-java/issues/915 - // - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count - return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) - .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(identifier), - ":sortprefix", getSortKeyPrefix(deviceId))) - .select(Select.COUNT) - .consistentRead(false) - .build())) - .map(QueryResponse::count) - .reduce(0, Integer::sum) - .toFuture(); - } - /** * Removes all single-use pre-keys for all devices associated with the given account/identity. * diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java index d28dd4f5b..071047891 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -17,8 +17,6 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.PreKey; abstract class SingleUsePreKeyStoreTest> { @@ -53,9 +51,8 @@ abstract class SingleUsePreKeyStoreTest> { assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void getCount(final boolean hasKeyCountAttribute) { + @Test + void getCount() { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); @@ -67,68 +64,17 @@ abstract class SingleUsePreKeyStoreTest> { preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); - if (!hasKeyCountAttribute) { - clearKeyCountAttributes(); - } - assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join()); for (int i = 0; i < KEY_COUNT; i++) { preKeyStore.take(accountIdentifier, deviceId).join(); assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join()); } - } - - @Test - void peekCount() { - final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = 1; - - assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join()); - - final List preKeys = generateRandomPreKeys(); - - preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); - - assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join()); - - for (int i = 0; i < KEY_COUNT; i++) { - preKeyStore.take(accountIdentifier, deviceId).join(); - assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join()); - } preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join(); clearKeyCountAttributes(); - assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join()); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void scanCount(final boolean hasKeyCountAttribute) { - final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = 1; - - assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join()); - - final List preKeys = generateRandomPreKeys(); - - preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); - - if (!hasKeyCountAttribute) { - clearKeyCountAttributes(); - } - - assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join()); - - for (int i = 0; i < KEY_COUNT; i++) { - preKeyStore.take(accountIdentifier, deviceId).join(); - assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join()); - } + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); } @Test