From d4d94038291a0add60e8bf2fabc567d65676bcbf Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 21 Jan 2021 15:43:15 -0500 Subject: [PATCH] Add a Dynamo-backed key store. --- .../WhisperServerConfiguration.java | 10 + .../textsecuregcm/WhisperServerService.java | 13 +- .../configuration/DynamoDbConfiguration.java | 50 +++++ .../MessageDynamoDbConfiguration.java | 27 +-- .../controllers/KeysController.java | 8 +- .../storage/AbstractDynamoDbStore.java | 76 +++++++ .../storage/AccountsManager.java | 7 +- .../textsecuregcm/storage/KeyRecord.java | 18 ++ .../textsecuregcm/storage/Keys.java | 32 ++- .../textsecuregcm/storage/KeysDynamoDb.java | 208 ++++++++++++++++++ .../storage/MessagesDynamoDb.java | 72 +----- .../textsecuregcm/storage/PreKeyStore.java | 23 ++ .../workers/DeleteUserCommand.java | 13 +- .../storage/KeysDynamoDbRule.java | 40 ++++ .../storage/KeysDynamoDbTest.java | 136 ++++++++++++ ...ollerTest.java => KeysControllerTest.java} | 30 ++- .../tests/storage/AccountsManagerTest.java | 20 +- .../textsecuregcm/tests/storage/KeysTest.java | 185 +++++++++------- 18 files changed, 758 insertions(+), 210 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbConfiguration.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/PreKeyStore.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbRule.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java rename service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/{KeyControllerTest.java => KeysControllerTest.java} (96%) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index fc3e2f872..1b6742656 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguratio import org.whispersystems.textsecuregcm.configuration.CdnConfiguration; import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration; import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration; +import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration; import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; @@ -128,6 +129,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private MessageDynamoDbConfiguration messageDynamoDb; + @Valid + @NotNull + @JsonProperty + private DynamoDbConfiguration keysDynamoDb; + @Valid @NotNull @JsonProperty @@ -306,6 +312,10 @@ public class WhisperServerConfiguration extends Configuration { return messageDynamoDb; } + public DynamoDbConfiguration getKeysDynamoDbConfiguration() { + return keysDynamoDb; + } + public DatabaseConfiguration getMessageStoreConfiguration() { return messageStore; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index de4d9e752..ef436ff8e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -126,6 +126,7 @@ import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.FeatureFlags; import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.Messages; import org.whispersystems.textsecuregcm.storage.MessagesCache; @@ -275,7 +276,16 @@ public class WhisperServerService extends Application 0) { count = count - 1; @@ -98,7 +98,7 @@ public class KeysController { } } - keys.store(account.getNumber(), device.getId(), preKeys.getPreKeys()); + keys.store(account, device.getId(), preKeys.getPreKeys()); } @Timed @@ -179,12 +179,12 @@ public class KeysController { private List getLocalKeys(Account destination, String deviceIdSelector) { try { if (deviceIdSelector.equals("*")) { - return keys.get(destination.getNumber()); + return keys.take(destination); } long deviceId = Long.parseLong(deviceIdSelector); - return keys.get(destination.getNumber(), deviceId); + return keys.take(destination, deviceId); } catch (NumberFormatException e) { throw new WebApplicationException(Response.status(422).build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java new file mode 100644 index 000000000..401b3979b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java @@ -0,0 +1,76 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; +import com.amazonaws.services.dynamodbv2.document.DynamoDB; +import com.amazonaws.services.dynamodbv2.document.TableWriteItems; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static com.codahale.metrics.MetricRegistry.name; +import static io.micrometer.core.instrument.Metrics.counter; +import static io.micrometer.core.instrument.Metrics.timer; + +public class AbstractDynamoDbStore { + + private final DynamoDB dynamoDb; + + private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); + private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); + private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high. + public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. + public static final int RESULT_SET_CHUNK_SIZE = 100; + + public AbstractDynamoDbStore(final DynamoDB dynamoDb) { + this.dynamoDb = dynamoDb; + } + + protected DynamoDB getDynamoDb() { + return dynamoDb; + } + + protected void executeTableWriteItemsUntilComplete(final TableWriteItems items) { + AtomicReference outcome = new AtomicReference<>(); + batchWriteItemsFirstPass.record(() -> outcome.set(dynamoDb.batchWriteItem(items))); + int attemptCount = 0; + while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { + batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems()))); + ++attemptCount; + } + if (!outcome.get().getUnprocessedItems().isEmpty()) { + logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); + batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); + } + } + + static void writeInBatches(final Iterable items, final Consumer> action) { + final List batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE); + + for (T item : items) { + batch.add(item); + + if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) { + action.accept(batch); + batch.clear(); + } + } + if (!batch.isEmpty()) { + action.accept(batch); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 44f0fb976..3d939df0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -56,6 +56,7 @@ public class AccountsManager { private final DirectoryManager directory; private final DirectoryQueue directoryQueue; private final Keys keys; + private final KeysDynamoDb keysDynamoDb; private final MessagesManager messagesManager; private final UsernamesManager usernamesManager; private final ProfilesManager profilesManager; @@ -73,12 +74,13 @@ public class AccountsManager { } } - public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { + public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final KeysDynamoDb keysDynamoDb, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { this.accounts = accounts; this.directory = directory; this.cacheCluster = cacheCluster; this.directoryQueue = directoryQueue; this.keys = keys; + this.keysDynamoDb = keysDynamoDb; this.messagesManager = messagesManager; this.usernamesManager = usernamesManager; this.profilesManager = profilesManager; @@ -150,7 +152,8 @@ public class AccountsManager { directoryQueue.deleteAccount(account); directory.remove(account.getNumber()); profilesManager.deleteAll(account.getUuid()); - keys.delete(account.getNumber()); + keys.delete(account); + keysDynamoDb.delete(account); messagesManager.clear(account.getNumber(), account.getUuid()); redisDelete(account); databaseDelete(account); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java index 17fe96f5d..c4f79487c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java @@ -5,6 +5,8 @@ package org.whispersystems.textsecuregcm.storage; +import java.util.Objects; + public class KeyRecord { private long id; @@ -41,4 +43,20 @@ public class KeyRecord { return publicKey; } + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final KeyRecord keyRecord = (KeyRecord)o; + return id == keyRecord.id && + deviceId == keyRecord.deviceId && + keyId == keyRecord.keyId && + Objects.equals(number, keyRecord.number) && + Objects.equals(publicKey, keyRecord.publicKey); + } + + @Override + public int hashCode() { + return Objects.hash(id, number, deviceId, keyId, publicKey); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index f620ad9ff..156276d8f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -27,7 +27,7 @@ import java.util.function.Supplier; import static com.codahale.metrics.MetricRegistry.name; -public class Keys { +public class Keys implements PreKeyStore { private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Meter fallbackMeter = metricRegistry.meter(name(Keys.class, "fallback")); @@ -49,7 +49,10 @@ public class Keys { this.retry = Retry.of("keys", retryConfiguration.toRetryConfigBuilder().build()); } - public void store(String number, long deviceId, List keys) { + @Override + public void store(Account account, long deviceId, List keys) { + final String number = account.getNumber(); + retry.executeRunnable(() -> { database.use(jdbi -> jdbi.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { try (Timer.Context ignored = storeTimer.time()) { @@ -74,8 +77,12 @@ public class Keys { }); } - public List get(String number, long deviceId) { - /* try { + @Override + public List take(Account account, long deviceId) { + /* + final String number = account.getNumber(); + + try { return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { try (Timer.Context ignored = getDevicetTimer.time()) { return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") @@ -95,8 +102,12 @@ public class Keys { return new LinkedList<>(); } - public List get(String number) { - /* try { + @Override + public List take(Account account) { + /* + final String number = account.getNumber(); + + try { return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { try (Timer.Context ignored = getTimer.time()) { return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") @@ -115,7 +126,10 @@ public class Keys { return new LinkedList<>(); } - public int getCount(String number, long deviceId) { + @Override + public int getCount(Account account, long deviceId) { + final String number = account.getNumber(); + return database.with(jdbi -> jdbi.withHandle(handle -> { try (Timer.Context ignored = getCountTimer.time()) { return handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") @@ -127,7 +141,9 @@ public class Keys { })); } - public void delete(final String number) { + public void delete(final Account account) { + final String number = account.getNumber(); + database.use(jdbi -> jdbi.useHandle(handle -> { try (Timer.Context ignored = getCountTimer.time()) { handle.createUpdate("DELETE FROM keys WHERE number = :number") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java new file mode 100644 index 000000000..1063a325d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java @@ -0,0 +1,208 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; +import com.amazonaws.services.dynamodbv2.document.DynamoDB; +import com.amazonaws.services.dynamodbv2.document.Item; +import com.amazonaws.services.dynamodbv2.document.PrimaryKey; +import com.amazonaws.services.dynamodbv2.document.Table; +import com.amazonaws.services.dynamodbv2.document.TableWriteItems; +import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; +import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; +import com.amazonaws.services.dynamodbv2.model.ReturnValue; +import com.amazonaws.services.dynamodbv2.model.Select; +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import org.whispersystems.textsecuregcm.entities.PreKey; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static com.codahale.metrics.MetricRegistry.name; + +public class KeysDynamoDb extends AbstractDynamoDbStore implements PreKeyStore { + + private final Table table; + + static final String KEY_ACCOUNT_UUID = "U"; + static final String KEY_DEVICE_ID_KEY_ID = "DK"; + static final String KEY_PUBLIC_KEY = "P"; + + private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(KeysDynamoDb.class, "storeKeys")); + private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForDevice")); + private static final Timer TAKE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForAccount")); + private static final Timer GET_KEY_COUNT_TIMER = Metrics.timer(name(KeysDynamoDb.class, "getKeyCount")); + private static final Timer DELETE_KEYS_FOR_DEVICE_TIMER = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForDevice")); + private static final Timer DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForAccount")); + private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(KeysDynamoDb.class, "contestedKeys")); + + public KeysDynamoDb(final DynamoDB dynamoDB, final String tableName) { + super(dynamoDB); + + this.table = dynamoDB.getTable(tableName); + } + + @Override + public void store(final Account account, final long deviceId, final List keys) { + STORE_KEYS_TIMER.record(() -> { + delete(account, deviceId); + + writeInBatches(keys, batch -> { + final TableWriteItems items = new TableWriteItems(table.getTableName()); + + for (final PreKey preKey : batch) { + items.addItemToPut(getItemFromPreKey(account.getUuid(), deviceId, preKey)); + } + + executeTableWriteItemsUntilComplete(items); + }); + }); + } + + @Override + public List take(final Account account, final long deviceId) { + return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { + final byte[] partitionKey = getPartitionKey(account.getUuid()); + + final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .withValueMap(Map.of(":uuid", partitionKey, + ":sortprefix", getSortKeyPrefix(deviceId))) + .withProjectionExpression(KEY_DEVICE_ID_KEY_ID) + .withConsistentRead(false); + + int contestedKeys = 0; + + try { + for (final Item candidate : table.query(querySpec)) { + final DeleteItemSpec deleteItemSpec = new DeleteItemSpec().withPrimaryKey(KEY_ACCOUNT_UUID, partitionKey, KEY_DEVICE_ID_KEY_ID, candidate.getBinary(KEY_DEVICE_ID_KEY_ID)) + .withReturnValues(ReturnValue.ALL_OLD); + + final DeleteItemOutcome outcome = table.deleteItem(deleteItemSpec); + + if (outcome.getItem() != null) { + final PreKey preKey = getPreKeyFromItem(outcome.getItem()); + return List.of(new KeyRecord(-1, account.getNumber(), deviceId, preKey.getKeyId(), preKey.getPublicKey())); + } + + contestedKeys++; + } + + return Collections.emptyList(); + } finally { + CONTESTED_KEY_DISTRIBUTION.record(contestedKeys); + } + }); + } + + @Override + public List take(final Account account) { + return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { + final List keyRecords = new ArrayList<>(); + + for (final Device device : account.getDevices()) { + keyRecords.addAll(take(account, device.getId())); + } + + return keyRecords; + }); + } + + @Override + public int getCount(final Account account, final long deviceId) { + return GET_KEY_COUNT_TIMER.record(() -> { + final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), + ":sortprefix", getSortKeyPrefix(deviceId))) + .withSelect(Select.COUNT) + .withConsistentRead(false); + + // This is very confusing, but does appear to be the intended behavior. See: + // + // - https://github.com/aws/aws-sdk-java/issues/693 + // - https://github.com/aws/aws-sdk-java/issues/915 + return table.query(querySpec).firstPage().getLowLevelResult().getQueryResult().getCount(); + }); + } + + @Override + public void delete(final Account account) { + DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { + final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid") + .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID)) + .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()))) + .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) + .withConsistentRead(true); + + deleteItemsMatchingQuery(querySpec); + }); + } + + @VisibleForTesting + void delete(final Account account, final long deviceId) { + DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { + final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), + ":sortprefix", getSortKeyPrefix(deviceId))) + .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) + .withConsistentRead(true); + + deleteItemsMatchingQuery(querySpec); + }); + } + + private void deleteItemsMatchingQuery(final QuerySpec querySpec) { + writeInBatches(table.query(querySpec), batch -> { + final TableWriteItems writeItems = new TableWriteItems(table.getTableName()); + + for (final Item item : batch) { + writeItems.addPrimaryKeyToDelete(new PrimaryKey(KEY_ACCOUNT_UUID, item.getBinary(KEY_ACCOUNT_UUID), KEY_DEVICE_ID_KEY_ID, item.getBinary(KEY_DEVICE_ID_KEY_ID))); + } + + executeTableWriteItemsUntilComplete(writeItems); + }); + } + + private static byte[] getPartitionKey(final UUID accountUuid) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); + byteBuffer.putLong(accountUuid.getMostSignificantBits()); + byteBuffer.putLong(accountUuid.getLeastSignificantBits()); + return byteBuffer.array(); + } + + private static byte[] getSortKey(final long deviceId, final long keyId) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); + byteBuffer.putLong(deviceId); + byteBuffer.putLong(keyId); + return byteBuffer.array(); + } + + private static byte[] getSortKeyPrefix(final long deviceId) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); + byteBuffer.putLong(deviceId); + return byteBuffer.array(); + } + + private Item getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { + return new Item().withBinary(KEY_ACCOUNT_UUID, getPartitionKey(accountUuid)) + .withBinary(KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId())) + .withString(KEY_PUBLIC_KEY, preKey.getPublicKey()); + } + + private PreKey getPreKeyFromItem(final Item item) { + final long keyId = ByteBuffer.wrap(item.getBinary(KEY_DEVICE_ID_KEY_ID)).getLong(8); + return new PreKey(keyId, item.getString(KEY_PUBLIC_KEY)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index 23449847d..149fa3e6e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.storage; -import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; import com.amazonaws.services.dynamodbv2.document.DynamoDB; import com.amazonaws.services.dynamodbv2.document.Index; @@ -17,11 +16,8 @@ import com.amazonaws.services.dynamodbv2.document.api.QueryApi; import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; import com.amazonaws.services.dynamodbv2.model.ReturnValue; -import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; @@ -33,17 +29,11 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import static com.codahale.metrics.MetricRegistry.name; -import static io.micrometer.core.instrument.Metrics.counter; import static io.micrometer.core.instrument.Metrics.timer; -public class MessagesDynamoDb { - private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high. - private static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. - public static final int RESULT_SET_CHUNK_SIZE = 100; +public class MessagesDynamoDb extends AbstractDynamoDbStore { private static final String KEY_PARTITION = "H"; private static final String KEY_SORT = "S"; @@ -60,10 +50,6 @@ public class MessagesDynamoDb { private static final String KEY_CONTENT = "C"; private static final String KEY_TTL = "E"; - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); - private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); - private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); private final Timer storeTimer = timer(name(getClass(), "store")); private final Timer loadTimer = timer(name(getClass(), "load")); private final Timer deleteBySourceAndTimestamp = timer(name(getClass(), "delete", "sourceAndTimestamp")); @@ -71,18 +57,18 @@ public class MessagesDynamoDb { private final Timer deleteByAccount = timer(name(getClass(), "delete", "account")); private final Timer deleteByDevice = timer(name(getClass(), "delete", "device")); - private final DynamoDB dynamoDb; private final String tableName; private final Duration timeToLive; public MessagesDynamoDb(DynamoDB dynamoDb, String tableName, Duration timeToLive) { - this.dynamoDb = dynamoDb; + super(dynamoDb); + this.tableName = tableName; this.timeToLive = timeToLive; } public void store(final List messages, final UUID destinationAccountUuid, final long destinationDeviceId) { - storeTimer.record(() -> doInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId), DYNAMO_DB_MAX_BATCH_SIZE)); + storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId))); } private void storeBatch(final List messages, final UUID destinationAccountUuid, final long destinationDeviceId) { @@ -135,7 +121,7 @@ public class MessagesDynamoDb { .withValueMap(Map.of(":part", partitionKey, ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) .withMaxResultSize(numberOfMessagesToFetch); - final Table table = dynamoDb.getTable(tableName); + final Table table = getDynamoDb().getTable(tableName); List messageEntities = new ArrayList<>(numberOfMessagesToFetch); for (Item message : table.query(querySpec)) { messageEntities.add(convertItemToOutgoingMessageEntity(message)); @@ -164,7 +150,7 @@ public class MessagesDynamoDb { ":source", source, ":timestamp", timestamp)); - final Table table = dynamoDb.getTable(tableName); + final Table table = getDynamoDb().getTable(tableName); return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, table); }); } @@ -179,7 +165,7 @@ public class MessagesDynamoDb { "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) .withValueMap(Map.of(":part", partitionKey, ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))); - final Table table = dynamoDb.getTable(tableName); + final Table table = getDynamoDb().getTable(tableName); final Index index = table.getIndex(LOCAL_INDEX_MESSAGE_UUID_NAME); return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, index); }); @@ -241,62 +227,24 @@ public class MessagesDynamoDb { } private void deleteRowsMatchingQuery(byte[] partitionKey, QuerySpec querySpec) { - final Table table = dynamoDb.getTable(tableName); - doInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch), DYNAMO_DB_MAX_BATCH_SIZE); + final Table table = getDynamoDb().getTable(tableName); + writeInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch)); } private void deleteItems(byte[] partitionKey, List items) { final TableWriteItems tableWriteItems = new TableWriteItems(tableName); - items.stream().map((x) -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, x.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); + items.stream().map(item -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, item.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); executeTableWriteItemsUntilComplete(tableWriteItems); } - private void executeTableWriteItemsUntilComplete(TableWriteItems items) { - AtomicReference outcome = new AtomicReference<>(); - batchWriteItemsFirstPass.record(() -> { - outcome.set(dynamoDb.batchWriteItem(items)); - }); - int attemptCount = 0; - while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { - batchWriteItemsRetryPass.record(() -> { - outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems())); - }); - ++attemptCount; - } - if (!outcome.get().getUnprocessedItems().isEmpty()) { - logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); - batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); - } - } - private long getTtlForMessage(MessageProtos.Envelope message) { return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); } - private static void doInBatches(final Iterable items, final Consumer> action, final int batchSize) { - List batch = new ArrayList<>(batchSize); - - for (T item : items) { - batch.add(item); - - if (batch.size() == batchSize) { - action.accept(batch); - batch.clear(); - } - } - if (!batch.isEmpty()) { - action.accept(batch); - } - } - private static byte[] convertPartitionKey(final UUID destinationAccountUuid) { return convertUuidToBytes(destinationAccountUuid); } - private static UUID convertPartitionKey(final byte[] bytes) { - return convertUuidFromBytes(bytes, "partition key"); - } - private static byte[] convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); byteBuffer.putLong(destinationDeviceId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PreKeyStore.java new file mode 100644 index 000000000..35551789c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PreKeyStore.java @@ -0,0 +1,23 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.entities.PreKey; + +import java.util.List; + +public interface PreKeyStore { + + void store(Account account, long deviceId, List keys); + + int getCount(Account account, long deviceId); + + List take(Account account, long deviceId); + + List take(Account account); + + void delete(Account account); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java index 53f7d1478..880b91acb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java @@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.Messages; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; @@ -97,7 +98,16 @@ public class DeleteUserCommand extends EnvironmentCommand account = accountsManager.get(user); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbRule.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbRule.java new file mode 100644 index 000000000..cee53509c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbRule.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.amazonaws.services.dynamodbv2.document.DynamoDB; +import com.amazonaws.services.dynamodbv2.model.AttributeDefinition; +import com.amazonaws.services.dynamodbv2.model.CreateTableRequest; +import com.amazonaws.services.dynamodbv2.model.KeySchemaElement; +import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput; +import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType; +import org.whispersystems.textsecuregcm.tests.util.LocalDynamoDbRule; + +public class KeysDynamoDbRule extends LocalDynamoDbRule { + public static final String TABLE_NAME = "Signal_Keys_Test"; + + @Override + protected void before() throws Throwable { + super.before(); + + final DynamoDB dynamoDB = getDynamoDB(); + + final CreateTableRequest createTableRequest = new CreateTableRequest() + .withTableName(TABLE_NAME) + .withKeySchema(new KeySchemaElement(KeysDynamoDb.KEY_ACCOUNT_UUID, "HASH"), + new KeySchemaElement(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, "RANGE")) + .withAttributeDefinitions(new AttributeDefinition(KeysDynamoDb.KEY_ACCOUNT_UUID, ScalarAttributeType.B), + new AttributeDefinition(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, ScalarAttributeType.B)) + .withProvisionedThroughput(new ProvisionedThroughput(20L, 20L)); + + dynamoDB.createTable(createTableRequest); + } + + @Override + protected void after() { + super.after(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java new file mode 100644 index 000000000..1fcc58880 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.PreKey; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KeysDynamoDbTest { + + private Account account; + private KeysDynamoDb keysDynamoDb; + + @ClassRule + public static KeysDynamoDbRule dynamoDbRule = new KeysDynamoDbRule(); + + private static final String ACCOUNT_NUMBER = "+18005551234"; + private static final long DEVICE_ID = 1L; + + @Before + public void setup() { + keysDynamoDb = new KeysDynamoDb(dynamoDbRule.getDynamoDB(), KeysDynamoDbRule.TABLE_NAME); + + account = mock(Account.class); + when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + } + + @Test + public void testStore() { + assertEquals("Initial pre-key count for an account should be zero", + 0, keysDynamoDb.getCount(account, DEVICE_ID)); + + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); + + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + assertEquals("Repeatedly storing same key should have no effect", + 1, keysDynamoDb.getCount(account, DEVICE_ID)); + + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); + assertEquals("Inserting a new key should overwrite all prior keys for the given account/device", + 1, keysDynamoDb.getCount(account, DEVICE_ID)); + + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); + assertEquals("Inserting multiple new keys should overwrite all prior keys for the given account/device", + 2, keysDynamoDb.getCount(account, DEVICE_ID)); + } + + @Test + public void testTakeAccount() { + final Device firstDevice = mock(Device.class); + final Device secondDevice = mock(Device.class); + + when(firstDevice.getId()).thenReturn(DEVICE_ID); + when(secondDevice.getId()).thenReturn(DEVICE_ID + 1); + when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice)); + + assertEquals(Collections.emptyList(), keysDynamoDb.take(account)); + + final PreKey firstDevicePreKey = new PreKey(1, "public-key"); + final PreKey secondDevicePreKey = new PreKey(2, "second-key"); + + keysDynamoDb.store(account, DEVICE_ID, List.of(firstDevicePreKey)); + keysDynamoDb.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey)); + + final Set expectedKeys = Set.of( + new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, firstDevicePreKey.getKeyId(), firstDevicePreKey.getPublicKey()), + new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID + 1, secondDevicePreKey.getKeyId(), secondDevicePreKey.getPublicKey())); + + assertEquals(expectedKeys, new HashSet<>(keysDynamoDb.take(account))); + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); + } + + @Test + public void testTakeAccountAndDeviceId() { + assertEquals(Collections.emptyList(), keysDynamoDb.take(account, DEVICE_ID)); + + final PreKey preKey = new PreKey(1, "public-key"); + + keysDynamoDb.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); + assertEquals(List.of(new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, preKey.getKeyId(), preKey.getPublicKey())), keysDynamoDb.take(account, DEVICE_ID)); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); + } + + @Test + public void testGetCount() { + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); + + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); + } + + @Test + public void testDeleteByAccount() { + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); + keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + + assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); + + keysDynamoDb.delete(account); + + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); + } + + @Test + public void testDeleteByAccountAndDevice() { + keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); + keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + + assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); + + keysDynamoDb.delete(account, DEVICE_ID); + + assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); + assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java similarity index 96% rename from service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index afd7839d4..bb2ec3d7a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -46,7 +46,7 @@ import io.dropwizard.testing.junit.ResourceTestRule; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; -public class KeyControllerTest { +public class KeysControllerTest { private static final String EXISTS_NUMBER = "+14152222222"; private static final UUID EXISTS_UUID = UUID.randomUUID(); @@ -141,18 +141,16 @@ public class KeyControllerTest { List singleDevice = new LinkedList<>(); singleDevice.add(SAMPLE_KEY); - when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice); - - when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>()); + when(keys.take(eq(existsAccount), eq(1L))).thenReturn(singleDevice); List multiDevice = new LinkedList<>(); multiDevice.add(SAMPLE_KEY); multiDevice.add(SAMPLE_KEY2); multiDevice.add(SAMPLE_KEY3); multiDevice.add(SAMPLE_KEY4); - when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice); + when(keys.take(existsAccount)).thenReturn(multiDevice); - when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); + when(keys.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); @@ -169,7 +167,7 @@ public class KeyControllerTest { assertThat(result.getCount()).isEqualTo(4); - verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); + verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); } @Test @@ -183,7 +181,7 @@ public class KeyControllerTest { assertThat(result.getCount()).isEqualTo(4); - verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); + verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); } @@ -283,7 +281,7 @@ public class KeyControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); + verify(keys).take(eq(existsAccount), eq(1L)); verifyNoMoreInteractions(keys); } @@ -301,7 +299,7 @@ public class KeyControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); + verify(keys).take(eq(existsAccount), eq(1L)); verifyNoMoreInteractions(keys); } @@ -320,7 +318,7 @@ public class KeyControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); + verify(keys).take(eq(existsAccount), eq(1L)); verifyNoMoreInteractions(keys); } @@ -338,7 +336,7 @@ public class KeyControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); + verify(keys).take(eq(existsAccount), eq(1L)); verifyNoMoreInteractions(keys); } @@ -414,7 +412,7 @@ public class KeyControllerTest { assertThat(signedPreKey).isNull(); assertThat(deviceId).isEqualTo(4); - verify(keys).get(eq(EXISTS_NUMBER)); + verify(keys).take(eq(existsAccount)); verifyNoMoreInteractions(keys); } @@ -464,7 +462,7 @@ public class KeyControllerTest { assertThat(signedPreKey).isNull(); assertThat(deviceId).isEqualTo(4); - verify(keys).get(eq(EXISTS_NUMBER)); + verify(keys).take(eq(existsAccount)); verifyNoMoreInteractions(keys); } @@ -533,7 +531,7 @@ public class KeyControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture()); + verify(keys).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -567,7 +565,7 @@ public class KeyControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(keys).store(eq(AuthHelper.DISABLED_NUMBER), eq(1L), listCaptor.capture()); + verify(keys).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java index c259edbbc..72dd3b989 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.storage; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import org.junit.Test; -import org.whispersystems.textsecuregcm.entities.Profile; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.Account; @@ -16,6 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.UsernamesManager; @@ -46,6 +46,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -55,7 +56,7 @@ public class AccountsManagerTest { when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString()); when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional account = accountsManager.get("+14152222222"); assertTrue(account.isPresent()); @@ -76,6 +77,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -84,7 +86,7 @@ public class AccountsManagerTest { when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional account = accountsManager.get(uuid); assertTrue(account.isPresent()); @@ -106,6 +108,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -115,7 +118,7 @@ public class AccountsManagerTest { when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null); when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional retrieved = accountsManager.get("+14152222222"); assertTrue(retrieved.isPresent()); @@ -138,6 +141,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -147,7 +151,7 @@ public class AccountsManagerTest { when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional retrieved = accountsManager.get(uuid); assertTrue(retrieved.isPresent()); @@ -170,6 +174,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -179,7 +184,7 @@ public class AccountsManagerTest { when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!")); when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional retrieved = accountsManager.get("+14152222222"); assertTrue(retrieved.isPresent()); @@ -202,6 +207,7 @@ public class AccountsManagerTest { DirectoryManager directoryManager = mock(DirectoryManager.class); DirectoryQueue directoryQueue = mock(DirectoryQueue.class); Keys keys = mock(Keys.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); UsernamesManager usernamesManager = mock(UsernamesManager.class); ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -211,7 +217,7 @@ public class AccountsManagerTest { when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); - AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); + AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); Optional retrieved = accountsManager.get(uuid); assertTrue(retrieved.isPresent()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java index f9f5f30e1..437188628 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java @@ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.KeyRecord; import org.whispersystems.textsecuregcm.storage.Keys; @@ -41,13 +42,14 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -@Ignore public class KeysTest { @Rule public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); - private Keys keys; + private Account firstAccount; + private Account secondAccount; + private Keys keys; @Before public void setup() { @@ -56,6 +58,12 @@ public class KeysTest { new CircuitBreakerConfiguration()); this.keys = new Keys(faultTolerantDatabase, new RetryConfiguration()); + + this.firstAccount = mock(Account.class); + this.secondAccount = mock(Account.class); + + when(firstAccount.getNumber()).thenReturn("+14152222222"); + when(secondAccount.getNumber()).thenReturn("+14151111111"); } @@ -79,18 +87,18 @@ public class KeysTest { anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); } - keys.store("+14152222222", 1, deviceOnePreKeys); - keys.store("+14152222222", 2, deviceTwoPreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); + keys.store(firstAccount, 2, deviceTwoPreKeys); - keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys); - keys.store("+14151111111", 1, anotherDeviceOnePreKeys); - keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); + keys.store(secondAccount, 1, oldAnotherDeviceOnePrKeys); + keys.store(secondAccount, 1, anotherDeviceOnePreKeys); + keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id"); - verifyStoredState(statement, "+14152222222", 1); - verifyStoredState(statement, "+14152222222", 2); - verifyStoredState(statement, "+14151111111", 1); - verifyStoredState(statement, "+14151111111", 2); + verifyStoredState(statement, firstAccount, 1); + verifyStoredState(statement, firstAccount, 2); + verifyStoredState(statement, secondAccount, 1); + verifyStoredState(statement, secondAccount, 2); } @Test @@ -102,11 +110,12 @@ public class KeysTest { } - keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); } + @Ignore @Test public void testGetForDevice() { List deviceOnePreKeys = new LinkedList<>(); @@ -125,45 +134,46 @@ public class KeysTest { anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); } - keys.store("+14152222222", 1, deviceOnePreKeys); - keys.store("+14152222222", 2, deviceTwoPreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); + keys.store(firstAccount, 2, deviceTwoPreKeys); - keys.store("+14151111111", 1, anotherDeviceOnePreKeys); - keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); + keys.store(secondAccount, 1, anotherDeviceOnePreKeys); + keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); - List records = keys.get("+14152222222", 1); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); + List records = keys.take(firstAccount, 1); assertThat(records.size()).isEqualTo(1); assertThat(records.get(0).getKeyId()).isEqualTo(1); assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1"); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); - records = keys.get("+14152222222", 1); + records = keys.take(firstAccount, 1); assertThat(records.size()).isEqualTo(1); assertThat(records.get(0).getKeyId()).isEqualTo(2); assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2"); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); - records = keys.get("+14152222222", 2); + records = keys.take(firstAccount, 2); assertThat(records.size()).isEqualTo(1); assertThat(records.get(0).getKeyId()).isEqualTo(1); assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1"); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); } + @Ignore @Test public void testGetForAllDevices() { List deviceOnePreKeys = new LinkedList<>(); @@ -184,18 +194,18 @@ public class KeysTest { anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); } - keys.store("+14152222222", 1, deviceOnePreKeys); - keys.store("+14152222222", 2, deviceTwoPreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); + keys.store(firstAccount, 2, deviceTwoPreKeys); - keys.store("+14151111111", 1, anotherDeviceOnePreKeys); - keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); - keys.store("+14151111111", 3, anotherDeviceThreePreKeys); + keys.store(secondAccount, 1, anotherDeviceOnePreKeys); + keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); + keys.store(secondAccount, 3, anotherDeviceThreePreKeys); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); - List records = keys.get("+14152222222"); + List records = keys.take(firstAccount); assertThat(records.size()).isEqualTo(2); assertThat(records.get(0).getKeyId()).isEqualTo(1); @@ -204,10 +214,10 @@ public class KeysTest { assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue(); assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue(); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); - records = keys.get("+14152222222"); + records = keys.take(firstAccount); assertThat(records.size()).isEqualTo(2); assertThat(records.get(0).getKeyId()).isEqualTo(2); @@ -216,11 +226,11 @@ public class KeysTest { assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue(); assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue(); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(98); - records = keys.get("+14151111111"); + records = keys.take(secondAccount); assertThat(records.size()).isEqualTo(3); assertThat(records.get(0).getKeyId()).isEqualTo(1); @@ -231,11 +241,12 @@ public class KeysTest { assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue(); assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue(); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99); - assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(99); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(99); + assertThat(keys.getCount(secondAccount, 3)).isEqualTo(99); } + @Ignore @Test public void testGetForAllDevicesParallel() throws InterruptedException { List deviceOnePreKeys = new LinkedList<>(); @@ -246,11 +257,11 @@ public class KeysTest { deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); } - keys.store("+14152222222", 1, deviceOnePreKeys); - keys.store("+14152222222", 2, deviceTwoPreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); + keys.store(firstAccount, 2, deviceTwoPreKeys); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); List threads = new LinkedList<>(); @@ -260,7 +271,7 @@ public class KeysTest { final int MAX_RETRIES = 5; for (int retryAttempt = 0; results == null && retryAttempt < MAX_RETRIES; ++retryAttempt) { try { - results = keys.get("+14152222222"); + results = keys.take(firstAccount); } catch (UnableToExecuteStatementException e) { if (retryAttempt == MAX_RETRIES - 1) { throw e; @@ -278,8 +289,8 @@ public class KeysTest { thread.join(); } - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(80); - assertThat(keys.getCount("+14152222222",2)).isEqualTo(80); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(80); + assertThat(keys.getCount(firstAccount,2)).isEqualTo(80); } @Test @@ -302,32 +313,32 @@ public class KeysTest { anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); } - keys.store("+14152222222", 1, deviceOnePreKeys); - keys.store("+14152222222", 2, deviceTwoPreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); + keys.store(firstAccount, 2, deviceTwoPreKeys); - keys.store("+14151111111", 1, anotherDeviceOnePreKeys); - keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); - keys.store("+14151111111", 3, anotherDeviceThreePreKeys); + keys.store(secondAccount, 1, anotherDeviceOnePreKeys); + keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); + keys.store(secondAccount, 3, anotherDeviceThreePreKeys); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); - keys.delete("+14152222222"); + keys.delete(firstAccount); - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(0); - assertThat(keys.getCount("+14152222222", 2)).isEqualTo(0); - assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); - assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); + assertThat(keys.getCount(firstAccount, 1)).isEqualTo(0); + assertThat(keys.getCount(firstAccount, 2)).isEqualTo(0); + assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); + assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); } @Test public void testEmptyKeyGet() { - List records = keys.get("+14152222222"); + List records = keys.take(firstAccount); assertThat(records.isEmpty()).isTrue(); } @@ -361,21 +372,21 @@ public class KeysTest { } try { - keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); throw new AssertionError(); } catch (TransactionException e) { // good } try { - keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); throw new AssertionError(); } catch (TransactionException e) { // good } try { - keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); throw new AssertionError(); } catch (CallNotPermittedException e) { // good @@ -384,7 +395,7 @@ public class KeysTest { Thread.sleep(1100); try { - keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store(firstAccount, 1, deviceOnePreKeys); throw new AssertionError(); } catch (TransactionException e) { // good @@ -401,7 +412,10 @@ public class KeysTest { Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); // We're happy as long as nothing throws an exception - keys.store("+18005551234", 1, Collections.emptyList()); + Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + + keys.store(account, 1, Collections.emptyList()); } @Test @@ -414,12 +428,15 @@ public class KeysTest { Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); - assertThat(keys.get("+18005551234")).isEqualTo(Collections.emptyList()); - assertThat(keys.get("+18005551234", 1)).isEqualTo(Collections.emptyList()); + Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + + assertThat(keys.take(account)).isEqualTo(Collections.emptyList()); + assertThat(keys.take(account, 1)).isEqualTo(Collections.emptyList()); } - private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException { - statement.setString(1, number); + private void verifyStoredState(PreparedStatement statement, Account account, int deviceId) throws SQLException { + statement.setString(1, account.getNumber()); statement.setInt(2, deviceId); ResultSet resultSet = statement.executeQuery(); @@ -431,7 +448,7 @@ public class KeysTest { assertThat(keyId).isEqualTo(rowCount); - assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount); + assertThat(publicKey).isEqualTo(account.getNumber() + "Device" + deviceId + "PublicKey" + rowCount); rowCount++; }