Always use the "peek" strategy for counting one-time pre-keys
This commit is contained in:
		
							parent
							
								
									f59c34004d
								
							
						
					
					
						commit
						796dce3cd3
					
				|  | @ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.storage; | ||||||
| import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; | import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; | ||||||
| import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE; | import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE; | ||||||
| 
 | 
 | ||||||
| import com.google.common.annotations.VisibleForTesting; | import io.micrometer.core.instrument.Counter; | ||||||
| import io.micrometer.core.instrument.DistributionSummary; | import io.micrometer.core.instrument.DistributionSummary; | ||||||
| import io.micrometer.core.instrument.Metrics; | import io.micrometer.core.instrument.Metrics; | ||||||
| import io.micrometer.core.instrument.Timer; | import io.micrometer.core.instrument.Timer; | ||||||
|  | @ -20,7 +20,6 @@ import java.util.Map; | ||||||
| import java.util.Optional; | import java.util.Optional; | ||||||
| import java.util.UUID; | import java.util.UUID; | ||||||
| import java.util.concurrent.CompletableFuture; | import java.util.concurrent.CompletableFuture; | ||||||
| import java.util.concurrent.atomic.AtomicBoolean; |  | ||||||
| import java.util.concurrent.atomic.AtomicInteger; | import java.util.concurrent.atomic.AtomicInteger; | ||||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
| import org.whispersystems.textsecuregcm.util.AttributeValues; | import org.whispersystems.textsecuregcm.util.AttributeValues; | ||||||
|  | @ -33,9 +32,7 @@ import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; | ||||||
| import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; | import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; | ||||||
| import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; | import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; | ||||||
| import software.amazon.awssdk.services.dynamodb.model.QueryRequest; | import software.amazon.awssdk.services.dynamodb.model.QueryRequest; | ||||||
| import software.amazon.awssdk.services.dynamodb.model.QueryResponse; |  | ||||||
| import software.amazon.awssdk.services.dynamodb.model.ReturnValue; | import software.amazon.awssdk.services.dynamodb.model.ReturnValue; | ||||||
| import software.amazon.awssdk.services.dynamodb.model.Select; |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key |  * A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key | ||||||
|  | @ -52,13 +49,14 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> { | ||||||
|   private final DynamoDbAsyncClient dynamoDbAsyncClient; |   private final DynamoDbAsyncClient dynamoDbAsyncClient; | ||||||
|   private final String tableName; |   private final String tableName; | ||||||
| 
 | 
 | ||||||
|   private final String getKeyCountTimerName = name(getClass(), "getCount"); |   private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount")); | ||||||
| 
 |  | ||||||
|   private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey")); |   private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey")); | ||||||
|   private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch")); |   private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch")); | ||||||
|   private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); |   private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); | ||||||
|   private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); |   private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); | ||||||
| 
 | 
 | ||||||
|  |   private final Counter noKeyCountAvailableCounter = Metrics.counter(name(getClass(), "noKeyCountAvailable")); | ||||||
|  | 
 | ||||||
|   final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary |   final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary | ||||||
|       .builder(name(getClass(), "keysConsideredForTake")) |       .builder(name(getClass(), "keysConsideredForTake")) | ||||||
|       .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) |       .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) | ||||||
|  | @ -181,28 +179,6 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> { | ||||||
|   public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) { |   public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) { | ||||||
|     final Timer.Sample sample = Timer.start(); |     final Timer.Sample sample = Timer.start(); | ||||||
| 
 | 
 | ||||||
|     final AtomicBoolean countFromPeek = new AtomicBoolean(true); |  | ||||||
| 
 |  | ||||||
|     return peekCount(identifier, deviceId) |  | ||||||
|         .thenCompose(maybeCount -> maybeCount |  | ||||||
|             .map(CompletableFuture::completedFuture) |  | ||||||
|             // Older key sets may not have a pre-calculated pre-key count; take the less efficient approach of counting |  | ||||||
|             // items instead |  | ||||||
|             .orElseGet(() -> { |  | ||||||
|               countFromPeek.set(false); |  | ||||||
|               return scanCount(identifier, deviceId); |  | ||||||
|             })) |  | ||||||
|         .whenComplete((keyCount, throwable) -> { |  | ||||||
|           sample.stop(Metrics.timer(getKeyCountTimerName, "method", countFromPeek.get() ? "peek" : "scan")); |  | ||||||
| 
 |  | ||||||
|           if (throwable == null && keyCount != null) { |  | ||||||
|             availableKeyCountDistributionSummary.record(keyCount); |  | ||||||
|           } |  | ||||||
|         }); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   @VisibleForTesting |  | ||||||
|   CompletableFuture<Optional<Integer>> peekCount(final UUID identifier, final byte deviceId) { |  | ||||||
|     return dynamoDbAsyncClient.query(QueryRequest.builder() |     return dynamoDbAsyncClient.query(QueryRequest.builder() | ||||||
|             .tableName(tableName) |             .tableName(tableName) | ||||||
|             .consistentRead(false) |             .consistentRead(false) | ||||||
|  | @ -219,38 +195,27 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> { | ||||||
|             final Map<String, AttributeValue> item = response.items().getFirst(); |             final Map<String, AttributeValue> item = response.items().getFirst(); | ||||||
| 
 | 
 | ||||||
|             if (item.containsKey(ATTR_REMAINING_KEYS)) { |             if (item.containsKey(ATTR_REMAINING_KEYS)) { | ||||||
|               return Optional.of(Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n())); |               return Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n()); | ||||||
|             } else { |             } else { | ||||||
|               return Optional.empty(); |               // Some legacy keys sets may not have pre-counted keys; in that case, we'll tell the owners of those key | ||||||
|  |               // sets that they have none remaining, prompting an upload of a fresh set that we'll pre-count. This has | ||||||
|  |               // no effect on consumers of keys, which will still be able to take keys if any are actually present. | ||||||
|  |               noKeyCountAvailableCounter.increment(); | ||||||
|  |               return 0; | ||||||
|             } |             } | ||||||
|           } else { |           } else { | ||||||
|             return Optional.of(0); |             return 0; | ||||||
|  |           } | ||||||
|  |         }) | ||||||
|  |         .whenComplete((keyCount, throwable) -> { | ||||||
|  |           sample.stop(getKeyCountTimer); | ||||||
|  | 
 | ||||||
|  |           if (throwable == null && keyCount != null) { | ||||||
|  |             availableKeyCountDistributionSummary.record(keyCount); | ||||||
|           } |           } | ||||||
|         }); |         }); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @VisibleForTesting |  | ||||||
|   CompletableFuture<Integer> scanCount(final UUID identifier, final byte deviceId) { |  | ||||||
|     // Getting an accurate count from DynamoDB can be very confusing. See: |  | ||||||
|     // |  | ||||||
|     // - https://github.com/aws/aws-sdk-java/issues/693 |  | ||||||
|     // - https://github.com/aws/aws-sdk-java/issues/915 |  | ||||||
|     // - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count |  | ||||||
|     return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() |  | ||||||
|             .tableName(tableName) |  | ||||||
|             .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") |  | ||||||
|             .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) |  | ||||||
|             .expressionAttributeValues(Map.of( |  | ||||||
|                 ":uuid", getPartitionKey(identifier), |  | ||||||
|                 ":sortprefix", getSortKeyPrefix(deviceId))) |  | ||||||
|             .select(Select.COUNT) |  | ||||||
|             .consistentRead(false) |  | ||||||
|             .build())) |  | ||||||
|         .map(QueryResponse::count) |  | ||||||
|         .reduce(0, Integer::sum) |  | ||||||
|         .toFuture(); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /** |   /** | ||||||
|    * Removes all single-use pre-keys for all devices associated with the given account/identity. |    * Removes all single-use pre-keys for all devices associated with the given account/identity. | ||||||
|    * |    * | ||||||
|  |  | ||||||
|  | @ -17,8 +17,6 @@ import java.util.Set; | ||||||
| import java.util.UUID; | import java.util.UUID; | ||||||
| import java.util.concurrent.ThreadLocalRandom; | import java.util.concurrent.ThreadLocalRandom; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; |  | ||||||
| import org.junit.jupiter.params.provider.ValueSource; |  | ||||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
| 
 | 
 | ||||||
| abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> { | abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> { | ||||||
|  | @ -53,9 +51,8 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> { | ||||||
|     assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); |     assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @ParameterizedTest |   @Test | ||||||
|   @ValueSource(booleans = {true, false}) |   void getCount() { | ||||||
|   void getCount(final boolean hasKeyCountAttribute) { |  | ||||||
|     final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore(); |     final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore(); | ||||||
| 
 | 
 | ||||||
|     final UUID accountIdentifier = UUID.randomUUID(); |     final UUID accountIdentifier = UUID.randomUUID(); | ||||||
|  | @ -67,68 +64,17 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> { | ||||||
| 
 | 
 | ||||||
|     preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); |     preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); | ||||||
| 
 | 
 | ||||||
|     if (!hasKeyCountAttribute) { |  | ||||||
|       clearKeyCountAttributes(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join()); |     assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join()); | ||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < KEY_COUNT; i++) { |     for (int i = 0; i < KEY_COUNT; i++) { | ||||||
|       preKeyStore.take(accountIdentifier, deviceId).join(); |       preKeyStore.take(accountIdentifier, deviceId).join(); | ||||||
|       assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join()); |       assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join()); | ||||||
|     } |     } | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   @Test |  | ||||||
|   void peekCount() { |  | ||||||
|     final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore(); |  | ||||||
| 
 |  | ||||||
|     final UUID accountIdentifier = UUID.randomUUID(); |  | ||||||
|     final byte deviceId = 1; |  | ||||||
| 
 |  | ||||||
|     assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join()); |  | ||||||
| 
 |  | ||||||
|     final List<K> preKeys = generateRandomPreKeys(); |  | ||||||
| 
 |  | ||||||
|     preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); |  | ||||||
| 
 |  | ||||||
|     assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join()); |  | ||||||
| 
 |  | ||||||
|     for (int i = 0; i < KEY_COUNT; i++) { |  | ||||||
|       preKeyStore.take(accountIdentifier, deviceId).join(); |  | ||||||
|       assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join()); |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join(); |     preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join(); | ||||||
|     clearKeyCountAttributes(); |     clearKeyCountAttributes(); | ||||||
| 
 | 
 | ||||||
|     assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join()); |     assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   @ParameterizedTest |  | ||||||
|   @ValueSource(booleans = {true, false}) |  | ||||||
|   void scanCount(final boolean hasKeyCountAttribute) { |  | ||||||
|     final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore(); |  | ||||||
| 
 |  | ||||||
|     final UUID accountIdentifier = UUID.randomUUID(); |  | ||||||
|     final byte deviceId = 1; |  | ||||||
| 
 |  | ||||||
|     assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join()); |  | ||||||
| 
 |  | ||||||
|     final List<K> preKeys = generateRandomPreKeys(); |  | ||||||
| 
 |  | ||||||
|     preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); |  | ||||||
| 
 |  | ||||||
|     if (!hasKeyCountAttribute) { |  | ||||||
|       clearKeyCountAttributes(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join()); |  | ||||||
| 
 |  | ||||||
|     for (int i = 0; i < KEY_COUNT; i++) { |  | ||||||
|       preKeyStore.take(accountIdentifier, deviceId).join(); |  | ||||||
|       assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join()); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Jon Chambers
						Jon Chambers