Always use the "peek" strategy for counting one-time pre-keys

This commit is contained in:
Jon Chambers 2024-04-01 13:29:17 -04:00 committed by Jon Chambers
parent f59c34004d
commit 796dce3cd3
2 changed files with 21 additions and 110 deletions

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
@ -20,7 +20,6 @@ import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
@ -33,9 +32,7 @@ import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.Select;
/**
* A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key
@ -52,13 +49,14 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
private final String getKeyCountTimerName = name(getClass(), "getCount");
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey"));
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable"));
final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary
.builder(name(getClass(), "keysConsideredForTake"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
@ -181,28 +179,6 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
final AtomicBoolean countFromPeek = new AtomicBoolean(true);
return peekCount(identifier, deviceId)
.thenCompose(maybeCount -> maybeCount
.map(CompletableFuture::completedFuture)
// Older key sets may not have a pre-calculated pre-key count; take the less efficient approach of counting
// items instead
.orElseGet(() -> {
countFromPeek.set(false);
return scanCount(identifier, deviceId);
}))
.whenComplete((keyCount, throwable) -> {
sample.stop(Metrics.timer(getKeyCountTimerName, "method", countFromPeek.get() ? "peek" : "scan"));
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
@VisibleForTesting
CompletableFuture<Optional<Integer>> peekCount(final UUID identifier, final byte deviceId) {
return dynamoDbAsyncClient.query(QueryRequest.builder()
.tableName(tableName)
.consistentRead(false)
@ -219,38 +195,27 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
final Map<String, AttributeValue> item = response.items().getFirst();
if (item.containsKey(ATTR_REMAINING_KEYS)) {
return Optional.of(Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n()));
return Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n());
} else {
return Optional.empty();
// Some legacy keys sets may not have pre-counted keys; in that case, we'll tell the owners of those key
// sets that they have none remaining, prompting an upload of a fresh set that we'll pre-count. This has
// no effect on consumers of keys, which will still be able to take keys if any are actually present.
noKeyCountAvailableCounter.increment();
return 0;
}
} else {
return Optional.of(0);
return 0;
}
})
.whenComplete((keyCount, throwable) -> {
sample.stop(getKeyCountTimer);
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
@VisibleForTesting
CompletableFuture<Integer> scanCount(final UUID identifier, final byte deviceId) {
// Getting an accurate count from DynamoDB can be very confusing. See:
//
// - https://github.com/aws/aws-sdk-java/issues/693
// - https://github.com/aws/aws-sdk-java/issues/915
// - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.select(Select.COUNT)
.consistentRead(false)
.build()))
.map(QueryResponse::count)
.reduce(0, Integer::sum)
.toFuture();
}
/**
* Removes all single-use pre-keys for all devices associated with the given account/identity.
*

View File

@ -17,8 +17,6 @@ import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.PreKey;
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
@ -53,9 +51,8 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void getCount(final boolean hasKeyCountAttribute) {
@Test
void getCount() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
@ -67,68 +64,17 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
if (!hasKeyCountAttribute) {
clearKeyCountAttributes();
}
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
for (int i = 0; i < KEY_COUNT; i++) {
preKeyStore.take(accountIdentifier, deviceId).join();
assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join());
}
}
@Test
void peekCount() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join());
final List<K> preKeys = generateRandomPreKeys();
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join());
for (int i = 0; i < KEY_COUNT; i++) {
preKeyStore.take(accountIdentifier, deviceId).join();
assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join());
}
preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join();
clearKeyCountAttributes();
assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void scanCount(final boolean hasKeyCountAttribute) {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = 1;
assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join());
final List<K> preKeys = generateRandomPreKeys();
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
if (!hasKeyCountAttribute) {
clearKeyCountAttributes();
}
assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join());
for (int i = 0; i < KEY_COUNT; i++) {
preKeyStore.take(accountIdentifier, deviceId).join();
assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join());
}
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
}
@Test