Use pre-calculated pre-key counts when possible
This commit is contained in:
parent
47fd8f5793
commit
7124621f66
|
@ -24,11 +24,16 @@ public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId, final ECPreKey preKey) {
|
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier,
|
||||||
|
final byte deviceId,
|
||||||
|
final ECPreKey preKey,
|
||||||
|
final int remainingKeys) {
|
||||||
|
|
||||||
return Map.of(
|
return Map.of(
|
||||||
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
|
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
|
||||||
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()),
|
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()),
|
||||||
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey()));
|
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey()),
|
||||||
|
ATTR_REMAINING_KEYS, AttributeValues.fromInt(remainingKeys));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -21,12 +21,17 @@ public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKe
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId, final KEMSignedPreKey signedPreKey) {
|
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier,
|
||||||
|
final byte deviceId,
|
||||||
|
final KEMSignedPreKey signedPreKey,
|
||||||
|
final int remainingKeys) {
|
||||||
|
|
||||||
return Map.of(
|
return Map.of(
|
||||||
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
|
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
|
||||||
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()),
|
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()),
|
||||||
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()),
|
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()),
|
||||||
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature()));
|
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature()),
|
||||||
|
ATTR_REMAINING_KEYS, AttributeValues.fromInt(remainingKeys));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -8,16 +8,19 @@ package org.whispersystems.textsecuregcm.storage;
|
||||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||||
import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE;
|
import static org.whispersystems.textsecuregcm.storage.AbstractDynamoDbStore.DYNAMO_DB_MAX_BATCH_SIZE;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import io.micrometer.core.instrument.DistributionSummary;
|
import io.micrometer.core.instrument.DistributionSummary;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.Timer;
|
import io.micrometer.core.instrument.Timer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import org.whispersystems.textsecuregcm.entities.PreKey;
|
import org.whispersystems.textsecuregcm.entities.PreKey;
|
||||||
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
||||||
|
@ -49,9 +52,10 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
private final DynamoDbAsyncClient dynamoDbAsyncClient;
|
private final DynamoDbAsyncClient dynamoDbAsyncClient;
|
||||||
private final String tableName;
|
private final String tableName;
|
||||||
|
|
||||||
|
private final String getKeyCountTimerName = name(getClass(), "getCount");
|
||||||
|
|
||||||
private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey"));
|
private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey"));
|
||||||
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
|
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
|
||||||
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
|
|
||||||
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
|
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
|
||||||
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
|
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
|
||||||
|
|
||||||
|
@ -74,6 +78,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
static final String KEY_DEVICE_ID_KEY_ID = "DK";
|
static final String KEY_DEVICE_ID_KEY_ID = "DK";
|
||||||
static final String ATTR_PUBLIC_KEY = "P";
|
static final String ATTR_PUBLIC_KEY = "P";
|
||||||
static final String ATTR_SIGNATURE = "S";
|
static final String ATTR_SIGNATURE = "S";
|
||||||
|
static final String ATTR_REMAINING_KEYS = "R";
|
||||||
|
|
||||||
protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
|
protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
|
||||||
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
|
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
|
||||||
|
@ -97,18 +102,22 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
return Mono.fromFuture(() -> delete(identifier, deviceId))
|
return Mono.fromFuture(() -> delete(identifier, deviceId))
|
||||||
.thenMany(
|
.thenMany(
|
||||||
Flux.fromIterable(preKeys)
|
Flux.fromIterable(preKeys)
|
||||||
.flatMap(preKey -> Mono.fromFuture(() -> store(identifier, deviceId, preKey)), DYNAMO_DB_MAX_BATCH_SIZE))
|
.sort(Comparator.comparing(preKey -> preKey.keyId()))
|
||||||
|
.zipWith(Flux.range(0, preKeys.size()).map(i -> preKeys.size() - i))
|
||||||
|
.flatMap(preKeyAndRemainingCount -> Mono.fromFuture(() ->
|
||||||
|
store(identifier, deviceId, preKeyAndRemainingCount.getT1(), preKeyAndRemainingCount.getT2())),
|
||||||
|
DYNAMO_DB_MAX_BATCH_SIZE))
|
||||||
.then()
|
.then()
|
||||||
.toFuture()
|
.toFuture()
|
||||||
.thenRun(() -> sample.stop(storeKeyBatchTimer));
|
.thenRun(() -> sample.stop(storeKeyBatchTimer));
|
||||||
}
|
}
|
||||||
|
|
||||||
private CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final K preKey) {
|
private CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final K preKey, final int remainingKeys) {
|
||||||
final Timer.Sample sample = Timer.start();
|
final Timer.Sample sample = Timer.start();
|
||||||
|
|
||||||
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
|
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
|
||||||
.tableName(tableName)
|
.tableName(tableName)
|
||||||
.item(getItemFromPreKey(identifier, deviceId, preKey))
|
.item(getItemFromPreKey(identifier, deviceId, preKey, remainingKeys))
|
||||||
.build())
|
.build())
|
||||||
.thenRun(() -> sample.stop(storeKeyTimer));
|
.thenRun(() -> sample.stop(storeKeyTimer));
|
||||||
}
|
}
|
||||||
|
@ -172,6 +181,56 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
|
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
|
||||||
final Timer.Sample sample = Timer.start();
|
final Timer.Sample sample = Timer.start();
|
||||||
|
|
||||||
|
final AtomicBoolean countFromPeek = new AtomicBoolean(true);
|
||||||
|
|
||||||
|
return peekCount(identifier, deviceId)
|
||||||
|
.thenCompose(maybeCount -> maybeCount
|
||||||
|
.map(CompletableFuture::completedFuture)
|
||||||
|
// Older key sets may not have a pre-calculated pre-key count; take the less efficient approach of counting
|
||||||
|
// items instead
|
||||||
|
.orElseGet(() -> {
|
||||||
|
countFromPeek.set(false);
|
||||||
|
return scanCount(identifier, deviceId);
|
||||||
|
}))
|
||||||
|
.whenComplete((keyCount, throwable) -> {
|
||||||
|
sample.stop(Metrics.timer(getKeyCountTimerName, "method", countFromPeek.get() ? "peek" : "scan"));
|
||||||
|
|
||||||
|
if (throwable == null && keyCount != null) {
|
||||||
|
availableKeyCountDistributionSummary.record(keyCount);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
CompletableFuture<Optional<Integer>> peekCount(final UUID identifier, final byte deviceId) {
|
||||||
|
return dynamoDbAsyncClient.query(QueryRequest.builder()
|
||||||
|
.tableName(tableName)
|
||||||
|
.consistentRead(false)
|
||||||
|
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
|
||||||
|
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
|
||||||
|
.expressionAttributeValues(Map.of(
|
||||||
|
":uuid", getPartitionKey(identifier),
|
||||||
|
":sortprefix", getSortKeyPrefix(deviceId)))
|
||||||
|
.projectionExpression(ATTR_REMAINING_KEYS)
|
||||||
|
.limit(1)
|
||||||
|
.build())
|
||||||
|
.thenApply(response -> {
|
||||||
|
if (response.count() > 0) {
|
||||||
|
final Map<String, AttributeValue> item = response.items().getFirst();
|
||||||
|
|
||||||
|
if (item.containsKey(ATTR_REMAINING_KEYS)) {
|
||||||
|
return Optional.of(Integer.parseInt(item.get(ATTR_REMAINING_KEYS).n()));
|
||||||
|
} else {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Optional.of(0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
CompletableFuture<Integer> scanCount(final UUID identifier, final byte deviceId) {
|
||||||
// Getting an accurate count from DynamoDB can be very confusing. See:
|
// Getting an accurate count from DynamoDB can be very confusing. See:
|
||||||
//
|
//
|
||||||
// - https://github.com/aws/aws-sdk-java/issues/693
|
// - https://github.com/aws/aws-sdk-java/issues/693
|
||||||
|
@ -189,14 +248,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
.build()))
|
.build()))
|
||||||
.map(QueryResponse::count)
|
.map(QueryResponse::count)
|
||||||
.reduce(0, Integer::sum)
|
.reduce(0, Integer::sum)
|
||||||
.toFuture()
|
.toFuture();
|
||||||
.whenComplete((keyCount, throwable) -> {
|
|
||||||
sample.stop(getKeyCountTimer);
|
|
||||||
|
|
||||||
if (throwable == null && keyCount != null) {
|
|
||||||
availableKeyCountDistributionSummary.record(keyCount);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -280,8 +332,10 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
|
||||||
return AttributeValues.fromByteBuffer(byteBuffer.flip());
|
return AttributeValues.fromByteBuffer(byteBuffer.flip());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId,
|
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier,
|
||||||
final K preKey);
|
final byte deviceId,
|
||||||
|
final K preKey,
|
||||||
|
final int remainingKeys);
|
||||||
|
|
||||||
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
|
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,12 @@ import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
|
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
|
||||||
|
|
||||||
|
@ -32,4 +38,24 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
|
||||||
protected ECPreKey generatePreKey(final long keyId) {
|
protected ECPreKey generatePreKey(final long keyId) {
|
||||||
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
|
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void clearKeyCountAttributes() {
|
||||||
|
final ScanIterable scanIterable = DYNAMO_DB_EXTENSION.getDynamoDbClient().scanPaginator(ScanRequest.builder()
|
||||||
|
.tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
for (final ScanResponse response : scanIterable) {
|
||||||
|
for (final Map<String, AttributeValue> item : response.items()) {
|
||||||
|
|
||||||
|
DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
|
||||||
|
.tableName(DynamoDbExtensionSchema.Tables.EC_KEYS.tableName())
|
||||||
|
.key(Map.of(
|
||||||
|
SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID),
|
||||||
|
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)))
|
||||||
|
.updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,12 @@ import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreKey> {
|
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreKey> {
|
||||||
|
|
||||||
|
@ -36,4 +42,24 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreK
|
||||||
protected KEMSignedPreKey generatePreKey(final long keyId) {
|
protected KEMSignedPreKey generatePreKey(final long keyId) {
|
||||||
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
|
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void clearKeyCountAttributes() {
|
||||||
|
final ScanIterable scanIterable = DYNAMO_DB_EXTENSION.getDynamoDbClient().scanPaginator(ScanRequest.builder()
|
||||||
|
.tableName(DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
for (final ScanResponse response : scanIterable) {
|
||||||
|
for (final Map<String, AttributeValue> item : response.items()) {
|
||||||
|
|
||||||
|
DYNAMO_DB_EXTENSION.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
|
||||||
|
.tableName(DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName())
|
||||||
|
.key(Map.of(
|
||||||
|
SingleUsePreKeyStore.KEY_ACCOUNT_UUID, item.get(SingleUsePreKeyStore.KEY_ACCOUNT_UUID),
|
||||||
|
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, item.get(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)))
|
||||||
|
.updateExpression("REMOVE " + SingleUsePreKeyStore.ATTR_REMAINING_KEYS)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,10 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
import org.whispersystems.textsecuregcm.entities.PreKey;
|
import org.whispersystems.textsecuregcm.entities.PreKey;
|
||||||
|
|
||||||
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
|
@ -23,6 +29,8 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
|
|
||||||
protected abstract K generatePreKey(final long keyId);
|
protected abstract K generatePreKey(final long keyId);
|
||||||
|
|
||||||
|
protected abstract void clearKeyCountAttributes();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void storeTake() {
|
void storeTake() {
|
||||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||||
|
@ -32,20 +40,22 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
|
|
||||||
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
|
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
final List<K> sortedPreKeys;
|
||||||
|
{
|
||||||
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
|
||||||
|
|
||||||
for (int i = 0; i < KEY_COUNT; i++) {
|
sortedPreKeys = new ArrayList<>(preKeys);
|
||||||
preKeys.add(generatePreKey(i));
|
sortedPreKeys.sort(Comparator.comparing(preKey -> preKey.keyId()));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
|
assertEquals(Optional.of(sortedPreKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||||
|
assertEquals(Optional.of(sortedPreKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
|
||||||
assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
|
|
||||||
assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
void getCount() {
|
@ValueSource(booleans = {true, false})
|
||||||
|
void getCount(final boolean hasKeyCountAttribute) {
|
||||||
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
@ -53,15 +63,72 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
|
|
||||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
|
||||||
for (int i = 0; i < KEY_COUNT; i++) {
|
|
||||||
preKeys.add(generatePreKey(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||||
|
|
||||||
|
if (!hasKeyCountAttribute) {
|
||||||
|
clearKeyCountAttributes();
|
||||||
|
}
|
||||||
|
|
||||||
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
|
for (int i = 0; i < KEY_COUNT; i++) {
|
||||||
|
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||||
|
assertEquals(KEY_COUNT - (i + 1), preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void peekCount() {
|
||||||
|
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = 1;
|
||||||
|
|
||||||
|
assertEquals(Optional.of(0), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
|
||||||
|
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||||
|
|
||||||
|
assertEquals(Optional.of(KEY_COUNT), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
|
for (int i = 0; i < KEY_COUNT; i++) {
|
||||||
|
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||||
|
assertEquals(Optional.of(KEY_COUNT - (i + 1)), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||||
|
}
|
||||||
|
|
||||||
|
preKeyStore.store(accountIdentifier, deviceId, List.of(generatePreKey(KEY_COUNT + 1))).join();
|
||||||
|
clearKeyCountAttributes();
|
||||||
|
|
||||||
|
assertEquals(Optional.empty(), preKeyStore.peekCount(accountIdentifier, deviceId).join());
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void scanCount(final boolean hasKeyCountAttribute) {
|
||||||
|
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = 1;
|
||||||
|
|
||||||
|
assertEquals(0, preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
|
||||||
|
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||||
|
|
||||||
|
if (!hasKeyCountAttribute) {
|
||||||
|
clearKeyCountAttributes();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(KEY_COUNT, preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
|
for (int i = 0; i < KEY_COUNT; i++) {
|
||||||
|
preKeyStore.take(accountIdentifier, deviceId).join();
|
||||||
|
assertEquals(KEY_COUNT - (i + 1), preKeyStore.scanCount(accountIdentifier, deviceId).join());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -74,11 +141,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
|
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
|
||||||
|
|
||||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
|
||||||
for (int i = 0; i < KEY_COUNT; i++) {
|
|
||||||
preKeys.add(generatePreKey(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||||
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
||||||
|
@ -99,11 +162,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
|
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
|
||||||
|
|
||||||
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
|
final List<K> preKeys = generateRandomPreKeys();
|
||||||
|
|
||||||
for (int i = 0; i < KEY_COUNT; i++) {
|
|
||||||
preKeys.add(generatePreKey(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
|
||||||
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
|
||||||
|
@ -113,4 +172,16 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
|
||||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
|
||||||
assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
|
assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<K> generateRandomPreKeys() {
|
||||||
|
final Set<Integer> keyIds = new HashSet<>(KEY_COUNT);
|
||||||
|
|
||||||
|
while (keyIds.size() < KEY_COUNT) {
|
||||||
|
keyIds.add(Math.abs(ThreadLocalRandom.current().nextInt()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return keyIds.stream()
|
||||||
|
.map(this::generatePreKey)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue