Add a Dynamo-backed key store.
This commit is contained in:
		
							parent
							
								
									426e6923ac
								
							
						
					
					
						commit
						d4d9403829
					
				|  | @ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguratio | |||
| import org.whispersystems.textsecuregcm.configuration.CdnConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | ||||
|  | @ -128,6 +129,11 @@ public class WhisperServerConfiguration extends Configuration { | |||
|   @JsonProperty | ||||
|   private MessageDynamoDbConfiguration messageDynamoDb; | ||||
| 
 | ||||
|   @Valid | ||||
|   @NotNull | ||||
|   @JsonProperty | ||||
|   private DynamoDbConfiguration keysDynamoDb; | ||||
| 
 | ||||
|   @Valid | ||||
|   @NotNull | ||||
|   @JsonProperty | ||||
|  | @ -306,6 +312,10 @@ public class WhisperServerConfiguration extends Configuration { | |||
|     return messageDynamoDb; | ||||
|   } | ||||
| 
 | ||||
|   public DynamoDbConfiguration getKeysDynamoDbConfiguration() { | ||||
|     return keysDynamoDb; | ||||
|   } | ||||
| 
 | ||||
|   public DatabaseConfiguration getMessageStoreConfiguration() { | ||||
|     return messageStore; | ||||
|   } | ||||
|  |  | |||
|  | @ -126,6 +126,7 @@ import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | |||
| import org.whispersystems.textsecuregcm.storage.FeatureFlags; | ||||
| import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; | ||||
| import org.whispersystems.textsecuregcm.storage.Keys; | ||||
| import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||
| import org.whispersystems.textsecuregcm.storage.MessagePersister; | ||||
| import org.whispersystems.textsecuregcm.storage.Messages; | ||||
| import org.whispersystems.textsecuregcm.storage.MessagesCache; | ||||
|  | @ -275,7 +276,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | |||
|             .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||
|                                                               .withRequestTimeout((int) config.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||
|             .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||
| 
 | ||||
|     AmazonDynamoDBClientBuilder keysDynamoDbClientBuilder = AmazonDynamoDBClientBuilder | ||||
|             .standard() | ||||
|             .withRegion(config.getKeysDynamoDbConfiguration().getRegion()) | ||||
|             .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getKeysDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||
|                                                               .withRequestTimeout((int) config.getKeysDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||
|             .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||
| 
 | ||||
|     DynamoDB messageDynamoDb = new DynamoDB(messageDynamoDbClientBuilder.build()); | ||||
|     DynamoDB preKeyDynamoDb = new DynamoDB(keysDynamoDbClientBuilder.build()); | ||||
| 
 | ||||
|     Accounts          accounts          = new Accounts(accountDatabase); | ||||
|     PendingAccounts   pendingAccounts   = new PendingAccounts(accountDatabase); | ||||
|  | @ -284,6 +294,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | |||
|     ReservedUsernames reservedUsernames = new ReservedUsernames(accountDatabase); | ||||
|     Profiles          profiles          = new Profiles(accountDatabase); | ||||
|     Keys              keys              = new Keys(accountDatabase, config.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); | ||||
|     KeysDynamoDb      keysDynamoDb      = new KeysDynamoDb(preKeyDynamoDb, config.getKeysDynamoDbConfiguration().getTableName()); | ||||
|     Messages          messages          = new Messages(messageDatabase); | ||||
|     MessagesDynamoDb  messagesDynamoDb  = new MessagesDynamoDb(messageDynamoDb, config.getMessageDynamoDbConfiguration().getTableName(), config.getMessageDynamoDbConfiguration().getTimeToLive()); | ||||
|     AbusiveHostRules  abusiveHostRules  = new AbusiveHostRules(abuseDatabase); | ||||
|  | @ -338,7 +349,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | |||
|     MessagesCache              messagesCache              = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor); | ||||
|     PushLatencyManager         pushLatencyManager         = new PushLatencyManager(metricsCluster); | ||||
|     MessagesManager            messagesManager            = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, experimentEnrollmentManager); | ||||
|     AccountsManager            accountsManager            = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager            accountsManager            = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     RemoteConfigsManager       remoteConfigsManager       = new RemoteConfigsManager(remoteConfigs); | ||||
|     FeatureFlagsManager        featureFlagsManager        = new FeatureFlagsManager(featureFlags, recurringJobExecutor); | ||||
|     DeadLetterHandler          deadLetterHandler          = new DeadLetterHandler(accountsManager, messagesManager); | ||||
|  |  | |||
|  | @ -0,0 +1,50 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.configuration; | ||||
| 
 | ||||
| import com.fasterxml.jackson.annotation.JsonProperty; | ||||
| 
 | ||||
| import javax.validation.Valid; | ||||
| import javax.validation.constraints.NotBlank; | ||||
| import javax.validation.constraints.NotEmpty; | ||||
| import java.time.Duration; | ||||
| 
 | ||||
| public class DynamoDbConfiguration { | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     @NotBlank | ||||
|     private String region; | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     @NotBlank | ||||
|     private String tableName; | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     private Duration clientExecutionTimeout = Duration.ofSeconds(30); | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     private Duration clientRequestTimeout = Duration.ofSeconds(10); | ||||
| 
 | ||||
|     @Valid | ||||
|     @NotEmpty | ||||
|     public String getRegion() { | ||||
|         return region; | ||||
|     } | ||||
| 
 | ||||
|     @Valid | ||||
|     @NotEmpty | ||||
|     public String getTableName() { | ||||
|         return tableName; | ||||
|     } | ||||
| 
 | ||||
|     public Duration getClientExecutionTimeout() { | ||||
|         return clientExecutionTimeout; | ||||
|     } | ||||
| 
 | ||||
|     public Duration getClientRequestTimeout() { | ||||
|         return clientRequestTimeout; | ||||
|     } | ||||
| } | ||||
|  | @ -9,35 +9,12 @@ import javax.validation.Valid; | |||
| import javax.validation.constraints.NotEmpty; | ||||
| import java.time.Duration; | ||||
| 
 | ||||
| public class MessageDynamoDbConfiguration { | ||||
|   private String region; | ||||
|   private String tableName; | ||||
| public class MessageDynamoDbConfiguration extends DynamoDbConfiguration { | ||||
| 
 | ||||
|   private Duration timeToLive = Duration.ofDays(7); | ||||
|   private Duration clientExecutionTimeout = Duration.ofSeconds(30); | ||||
|   private Duration clientRequestTimeout = Duration.ofSeconds(10); | ||||
| 
 | ||||
|   @Valid | ||||
|   @NotEmpty | ||||
|   public String getRegion() { | ||||
|     return region; | ||||
|   } | ||||
| 
 | ||||
|   @Valid | ||||
|   @NotEmpty | ||||
|   public String getTableName() { | ||||
|     return tableName; | ||||
|   } | ||||
| 
 | ||||
|   @Valid | ||||
|   public Duration getTimeToLive() { | ||||
|     return timeToLive; | ||||
|   } | ||||
| 
 | ||||
|   public Duration getClientExecutionTimeout() { | ||||
|     return clientExecutionTimeout; | ||||
|   } | ||||
| 
 | ||||
|   public Duration getClientRequestTimeout() { | ||||
|     return clientRequestTimeout; | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -62,7 +62,7 @@ public class KeysController { | |||
|   @GET | ||||
|   @Produces(MediaType.APPLICATION_JSON) | ||||
|   public PreKeyCount getStatus(@Auth Account account) { | ||||
|     int count = keys.getCount(account.getNumber(), account.getAuthenticatedDevice().get().getId()); | ||||
|     int count = keys.getCount(account, account.getAuthenticatedDevice().get().getId()); | ||||
| 
 | ||||
|     if (count > 0) { | ||||
|       count = count - 1; | ||||
|  | @ -98,7 +98,7 @@ public class KeysController { | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     keys.store(account.getNumber(), device.getId(), preKeys.getPreKeys()); | ||||
|     keys.store(account, device.getId(), preKeys.getPreKeys()); | ||||
|   } | ||||
| 
 | ||||
|   @Timed | ||||
|  | @ -179,12 +179,12 @@ public class KeysController { | |||
|   private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) { | ||||
|     try { | ||||
|       if (deviceIdSelector.equals("*")) { | ||||
|         return keys.get(destination.getNumber()); | ||||
|         return keys.take(destination); | ||||
|       } | ||||
| 
 | ||||
|       long deviceId = Long.parseLong(deviceIdSelector); | ||||
| 
 | ||||
|       return keys.get(destination.getNumber(), deviceId); | ||||
|       return keys.take(destination, deviceId); | ||||
|     } catch (NumberFormatException e) { | ||||
|       throw new WebApplicationException(Response.status(422).build()); | ||||
|     } | ||||
|  |  | |||
|  | @ -0,0 +1,76 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; | ||||
| import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||
| import com.amazonaws.services.dynamodbv2.document.TableWriteItems; | ||||
| import io.micrometer.core.instrument.Counter; | ||||
| import io.micrometer.core.instrument.Timer; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.atomic.AtomicReference; | ||||
| import java.util.function.Consumer; | ||||
| 
 | ||||
| import static com.codahale.metrics.MetricRegistry.name; | ||||
| import static io.micrometer.core.instrument.Metrics.counter; | ||||
| import static io.micrometer.core.instrument.Metrics.timer; | ||||
| 
 | ||||
| public class AbstractDynamoDbStore { | ||||
| 
 | ||||
|     private final DynamoDB dynamoDb; | ||||
| 
 | ||||
|     private final Timer   batchWriteItemsFirstPass   = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); | ||||
|     private final Timer   batchWriteItemsRetryPass   = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); | ||||
|     private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); | ||||
| 
 | ||||
|     private final Logger logger = LoggerFactory.getLogger(getClass()); | ||||
| 
 | ||||
|     private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25;  // This was arbitrarily chosen and may be entirely too high. | ||||
|     public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25;  // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. | ||||
|     public static final int RESULT_SET_CHUNK_SIZE = 100; | ||||
| 
 | ||||
|     public AbstractDynamoDbStore(final DynamoDB dynamoDb) { | ||||
|         this.dynamoDb = dynamoDb; | ||||
|     } | ||||
| 
 | ||||
|     protected DynamoDB getDynamoDb() { | ||||
|         return dynamoDb; | ||||
|     } | ||||
| 
 | ||||
|     protected void executeTableWriteItemsUntilComplete(final TableWriteItems items) { | ||||
|         AtomicReference<BatchWriteItemOutcome> outcome = new AtomicReference<>(); | ||||
|         batchWriteItemsFirstPass.record(() -> outcome.set(dynamoDb.batchWriteItem(items))); | ||||
|         int attemptCount = 0; | ||||
|         while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { | ||||
|             batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems()))); | ||||
|             ++attemptCount; | ||||
|         } | ||||
|         if (!outcome.get().getUnprocessedItems().isEmpty()) { | ||||
|             logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); | ||||
|             batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) { | ||||
|         final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE); | ||||
| 
 | ||||
|         for (T item : items) { | ||||
|             batch.add(item); | ||||
| 
 | ||||
|             if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) { | ||||
|                 action.accept(batch); | ||||
|                 batch.clear(); | ||||
|             } | ||||
|         } | ||||
|         if (!batch.isEmpty()) { | ||||
|             action.accept(batch); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | @ -56,6 +56,7 @@ public class AccountsManager { | |||
|   private final DirectoryManager          directory; | ||||
|   private final DirectoryQueue            directoryQueue; | ||||
|   private final Keys                      keys; | ||||
|   private final KeysDynamoDb              keysDynamoDb; | ||||
|   private final MessagesManager           messagesManager; | ||||
|   private final UsernamesManager          usernamesManager; | ||||
|   private final ProfilesManager           profilesManager; | ||||
|  | @ -73,12 +74,13 @@ public class AccountsManager { | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { | ||||
|   public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final KeysDynamoDb keysDynamoDb, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { | ||||
|     this.accounts         = accounts; | ||||
|     this.directory        = directory; | ||||
|     this.cacheCluster     = cacheCluster; | ||||
|     this.directoryQueue   = directoryQueue; | ||||
|     this.keys             = keys; | ||||
|     this.keysDynamoDb     = keysDynamoDb; | ||||
|     this.messagesManager  = messagesManager; | ||||
|     this.usernamesManager = usernamesManager; | ||||
|     this.profilesManager  = profilesManager; | ||||
|  | @ -150,7 +152,8 @@ public class AccountsManager { | |||
|       directoryQueue.deleteAccount(account); | ||||
|       directory.remove(account.getNumber()); | ||||
|       profilesManager.deleteAll(account.getUuid()); | ||||
|       keys.delete(account.getNumber()); | ||||
|       keys.delete(account); | ||||
|       keysDynamoDb.delete(account); | ||||
|       messagesManager.clear(account.getNumber(), account.getUuid()); | ||||
|       redisDelete(account); | ||||
|       databaseDelete(account); | ||||
|  |  | |||
|  | @ -5,6 +5,8 @@ | |||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import java.util.Objects; | ||||
| 
 | ||||
| public class KeyRecord { | ||||
| 
 | ||||
|   private long    id; | ||||
|  | @ -41,4 +43,20 @@ public class KeyRecord { | |||
|     return publicKey; | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public boolean equals(final Object o) { | ||||
|     if (this == o) return true; | ||||
|     if (o == null || getClass() != o.getClass()) return false; | ||||
|     final KeyRecord keyRecord = (KeyRecord)o; | ||||
|     return id == keyRecord.id && | ||||
|             deviceId == keyRecord.deviceId && | ||||
|             keyId == keyRecord.keyId && | ||||
|             Objects.equals(number, keyRecord.number) && | ||||
|             Objects.equals(publicKey, keyRecord.publicKey); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public int hashCode() { | ||||
|     return Objects.hash(id, number, deviceId, keyId, publicKey); | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ import java.util.function.Supplier; | |||
| 
 | ||||
| import static com.codahale.metrics.MetricRegistry.name; | ||||
| 
 | ||||
| public class Keys { | ||||
| public class Keys implements PreKeyStore { | ||||
| 
 | ||||
|   private final MetricRegistry metricRegistry  = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); | ||||
|   private final Meter          fallbackMeter   = metricRegistry.meter(name(Keys.class, "fallback")); | ||||
|  | @ -49,7 +49,10 @@ public class Keys { | |||
|     this.retry = Retry.of("keys", retryConfiguration.toRetryConfigBuilder().build()); | ||||
|   } | ||||
| 
 | ||||
|   public void store(String number, long deviceId, List<PreKey> keys) { | ||||
|   @Override | ||||
|   public void store(Account account, long deviceId, List<PreKey> keys) { | ||||
|     final String number = account.getNumber(); | ||||
| 
 | ||||
|     retry.executeRunnable(() -> { | ||||
|       database.use(jdbi -> jdbi.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||
|         try (Timer.Context ignored = storeTimer.time()) { | ||||
|  | @ -74,8 +77,12 @@ public class Keys { | |||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   public List<KeyRecord> get(String number, long deviceId) { | ||||
|     /* try { | ||||
|   @Override | ||||
|   public List<KeyRecord> take(Account account, long deviceId) { | ||||
|     /* | ||||
|     final String number = account.getNumber(); | ||||
| 
 | ||||
|     try { | ||||
|       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||
|         try (Timer.Context ignored = getDevicetTimer.time()) { | ||||
|           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") | ||||
|  | @ -95,8 +102,12 @@ public class Keys { | |||
|     return new LinkedList<>(); | ||||
|   } | ||||
| 
 | ||||
|   public List<KeyRecord> get(String number) { | ||||
|     /* try { | ||||
|   @Override | ||||
|   public List<KeyRecord> take(Account account) { | ||||
|     /* | ||||
|     final String number = account.getNumber(); | ||||
| 
 | ||||
|     try { | ||||
|       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||
|         try (Timer.Context ignored = getTimer.time()) { | ||||
|           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") | ||||
|  | @ -115,7 +126,10 @@ public class Keys { | |||
|     return new LinkedList<>(); | ||||
|   } | ||||
| 
 | ||||
|   public int getCount(String number, long deviceId) { | ||||
|   @Override | ||||
|   public int getCount(Account account, long deviceId) { | ||||
|     final String number = account.getNumber(); | ||||
| 
 | ||||
|     return database.with(jdbi -> jdbi.withHandle(handle -> { | ||||
|       try (Timer.Context ignored = getCountTimer.time()) { | ||||
|         return handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") | ||||
|  | @ -127,7 +141,9 @@ public class Keys { | |||
|     })); | ||||
|   } | ||||
| 
 | ||||
|   public void delete(final String number) { | ||||
|   public void delete(final Account account) { | ||||
|     final String number = account.getNumber(); | ||||
| 
 | ||||
|     database.use(jdbi -> jdbi.useHandle(handle -> { | ||||
|       try (Timer.Context ignored = getCountTimer.time()) { | ||||
|         handle.createUpdate("DELETE FROM keys WHERE number = :number") | ||||
|  |  | |||
|  | @ -0,0 +1,208 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; | ||||
| import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||
| import com.amazonaws.services.dynamodbv2.document.Item; | ||||
| import com.amazonaws.services.dynamodbv2.document.PrimaryKey; | ||||
| import com.amazonaws.services.dynamodbv2.document.Table; | ||||
| import com.amazonaws.services.dynamodbv2.document.TableWriteItems; | ||||
| import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; | ||||
| import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; | ||||
| import com.amazonaws.services.dynamodbv2.model.ReturnValue; | ||||
| import com.amazonaws.services.dynamodbv2.model.Select; | ||||
| import com.google.common.annotations.VisibleForTesting; | ||||
| import io.micrometer.core.instrument.DistributionSummary; | ||||
| import io.micrometer.core.instrument.Metrics; | ||||
| import io.micrometer.core.instrument.Timer; | ||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | ||||
| 
 | ||||
| import java.nio.ByteBuffer; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| import static com.codahale.metrics.MetricRegistry.name; | ||||
| 
 | ||||
| public class KeysDynamoDb extends AbstractDynamoDbStore implements PreKeyStore { | ||||
| 
 | ||||
|     private final Table table; | ||||
| 
 | ||||
|     static final String KEY_ACCOUNT_UUID = "U"; | ||||
|     static final String KEY_DEVICE_ID_KEY_ID = "DK"; | ||||
|     static final String KEY_PUBLIC_KEY = "P"; | ||||
| 
 | ||||
|     private static final Timer               STORE_KEYS_TIMER              = Metrics.timer(name(KeysDynamoDb.class, "storeKeys")); | ||||
|     private static final Timer               TAKE_KEY_FOR_DEVICE_TIMER     = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForDevice")); | ||||
|     private static final Timer               TAKE_KEYS_FOR_ACCOUNT_TIMER   = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForAccount")); | ||||
|     private static final Timer               GET_KEY_COUNT_TIMER           = Metrics.timer(name(KeysDynamoDb.class, "getKeyCount")); | ||||
|     private static final Timer               DELETE_KEYS_FOR_DEVICE_TIMER  = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForDevice")); | ||||
|     private static final Timer               DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForAccount")); | ||||
|     private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION    = Metrics.summary(name(KeysDynamoDb.class, "contestedKeys")); | ||||
| 
 | ||||
|     public KeysDynamoDb(final DynamoDB dynamoDB, final String tableName) { | ||||
|         super(dynamoDB); | ||||
| 
 | ||||
|         this.table = dynamoDB.getTable(tableName); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void store(final Account account, final long deviceId, final List<PreKey> keys) { | ||||
|         STORE_KEYS_TIMER.record(() -> { | ||||
|             delete(account, deviceId); | ||||
| 
 | ||||
|             writeInBatches(keys, batch -> { | ||||
|                 final TableWriteItems items = new TableWriteItems(table.getTableName()); | ||||
| 
 | ||||
|                 for (final PreKey preKey : batch) { | ||||
|                     items.addItemToPut(getItemFromPreKey(account.getUuid(), deviceId, preKey)); | ||||
|                 } | ||||
| 
 | ||||
|                 executeTableWriteItemsUntilComplete(items); | ||||
|             }); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<KeyRecord> take(final Account account, final long deviceId) { | ||||
|         return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { | ||||
|             final byte[] partitionKey = getPartitionKey(account.getUuid()); | ||||
| 
 | ||||
|             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||
|                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||
|                                                        .withValueMap(Map.of(":uuid", partitionKey, | ||||
|                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||
|                                                        .withProjectionExpression(KEY_DEVICE_ID_KEY_ID) | ||||
|                                                        .withConsistentRead(false); | ||||
| 
 | ||||
|             int contestedKeys = 0; | ||||
| 
 | ||||
|             try { | ||||
|                 for (final Item candidate : table.query(querySpec)) { | ||||
|                     final DeleteItemSpec deleteItemSpec = new DeleteItemSpec().withPrimaryKey(KEY_ACCOUNT_UUID, partitionKey, KEY_DEVICE_ID_KEY_ID, candidate.getBinary(KEY_DEVICE_ID_KEY_ID)) | ||||
|                                                                               .withReturnValues(ReturnValue.ALL_OLD); | ||||
| 
 | ||||
|                     final DeleteItemOutcome outcome = table.deleteItem(deleteItemSpec); | ||||
| 
 | ||||
|                     if (outcome.getItem() != null) { | ||||
|                         final PreKey preKey = getPreKeyFromItem(outcome.getItem()); | ||||
|                         return List.of(new KeyRecord(-1, account.getNumber(), deviceId, preKey.getKeyId(), preKey.getPublicKey())); | ||||
|                     } | ||||
| 
 | ||||
|                     contestedKeys++; | ||||
|                 } | ||||
| 
 | ||||
|                 return Collections.emptyList(); | ||||
|             } finally { | ||||
|                 CONTESTED_KEY_DISTRIBUTION.record(contestedKeys); | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<KeyRecord> take(final Account account) { | ||||
|         return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { | ||||
|             final List<KeyRecord> keyRecords = new ArrayList<>(); | ||||
| 
 | ||||
|             for (final Device device : account.getDevices()) { | ||||
|                 keyRecords.addAll(take(account, device.getId())); | ||||
|             } | ||||
| 
 | ||||
|             return keyRecords; | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int getCount(final Account account, final long deviceId) { | ||||
|         return GET_KEY_COUNT_TIMER.record(() -> { | ||||
|             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||
|                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||
|                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), | ||||
|                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||
|                                                        .withSelect(Select.COUNT) | ||||
|                                                        .withConsistentRead(false); | ||||
| 
 | ||||
|             // This is very confusing, but does appear to be the intended behavior. See: | ||||
|             // | ||||
|             // - https://github.com/aws/aws-sdk-java/issues/693 | ||||
|             // - https://github.com/aws/aws-sdk-java/issues/915 | ||||
|             return table.query(querySpec).firstPage().getLowLevelResult().getQueryResult().getCount(); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void delete(final Account account) { | ||||
|         DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { | ||||
|             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid") | ||||
|                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID)) | ||||
|                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()))) | ||||
|                                                        .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) | ||||
|                                                        .withConsistentRead(true); | ||||
| 
 | ||||
|             deleteItemsMatchingQuery(querySpec); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @VisibleForTesting | ||||
|     void delete(final Account account, final long deviceId) { | ||||
|         DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { | ||||
|             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||
|                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||
|                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), | ||||
|                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||
|                                                        .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) | ||||
|                                                        .withConsistentRead(true); | ||||
| 
 | ||||
|             deleteItemsMatchingQuery(querySpec); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     private void deleteItemsMatchingQuery(final QuerySpec querySpec) { | ||||
|         writeInBatches(table.query(querySpec), batch -> { | ||||
|             final TableWriteItems writeItems = new TableWriteItems(table.getTableName()); | ||||
| 
 | ||||
|             for (final Item item : batch) { | ||||
|                 writeItems.addPrimaryKeyToDelete(new PrimaryKey(KEY_ACCOUNT_UUID, item.getBinary(KEY_ACCOUNT_UUID), KEY_DEVICE_ID_KEY_ID, item.getBinary(KEY_DEVICE_ID_KEY_ID))); | ||||
|             } | ||||
| 
 | ||||
|             executeTableWriteItemsUntilComplete(writeItems); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     private static byte[] getPartitionKey(final UUID accountUuid) { | ||||
|         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); | ||||
|         byteBuffer.putLong(accountUuid.getMostSignificantBits()); | ||||
|         byteBuffer.putLong(accountUuid.getLeastSignificantBits()); | ||||
|         return byteBuffer.array(); | ||||
|     } | ||||
| 
 | ||||
|     private static byte[] getSortKey(final long deviceId, final long keyId) { | ||||
|         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); | ||||
|         byteBuffer.putLong(deviceId); | ||||
|         byteBuffer.putLong(keyId); | ||||
|         return byteBuffer.array(); | ||||
|     } | ||||
| 
 | ||||
|     private static byte[] getSortKeyPrefix(final long deviceId) { | ||||
|         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); | ||||
|         byteBuffer.putLong(deviceId); | ||||
|         return byteBuffer.array(); | ||||
|     } | ||||
| 
 | ||||
|     private Item getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { | ||||
|         return new Item().withBinary(KEY_ACCOUNT_UUID, getPartitionKey(accountUuid)) | ||||
|                          .withBinary(KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId())) | ||||
|                          .withString(KEY_PUBLIC_KEY, preKey.getPublicKey()); | ||||
|     } | ||||
| 
 | ||||
|     private PreKey getPreKeyFromItem(final Item item) { | ||||
|         final long keyId = ByteBuffer.wrap(item.getBinary(KEY_DEVICE_ID_KEY_ID)).getLong(8); | ||||
|         return new PreKey(keyId, item.getString(KEY_PUBLIC_KEY)); | ||||
|     } | ||||
| } | ||||
|  | @ -5,7 +5,6 @@ | |||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; | ||||
| import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; | ||||
| import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||
| import com.amazonaws.services.dynamodbv2.document.Index; | ||||
|  | @ -17,11 +16,8 @@ import com.amazonaws.services.dynamodbv2.document.api.QueryApi; | |||
| import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; | ||||
| import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; | ||||
| import com.amazonaws.services.dynamodbv2.model.ReturnValue; | ||||
| import io.micrometer.core.instrument.Counter; | ||||
| import io.micrometer.core.instrument.Timer; | ||||
| import org.apache.commons.lang3.StringUtils; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| import org.whispersystems.textsecuregcm.entities.MessageProtos; | ||||
| import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; | ||||
| 
 | ||||
|  | @ -33,17 +29,11 @@ import java.util.List; | |||
| import java.util.Map; | ||||
| import java.util.Optional; | ||||
| import java.util.UUID; | ||||
| import java.util.concurrent.atomic.AtomicReference; | ||||
| import java.util.function.Consumer; | ||||
| 
 | ||||
| import static com.codahale.metrics.MetricRegistry.name; | ||||
| import static io.micrometer.core.instrument.Metrics.counter; | ||||
| import static io.micrometer.core.instrument.Metrics.timer; | ||||
| 
 | ||||
| public class MessagesDynamoDb { | ||||
|   private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25;  // This was arbitrarily chosen and may be entirely too high. | ||||
|   private static final int DYNAMO_DB_MAX_BATCH_SIZE = 25;  // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. | ||||
|   public static final int RESULT_SET_CHUNK_SIZE = 100; | ||||
| public class MessagesDynamoDb extends AbstractDynamoDbStore { | ||||
| 
 | ||||
|   private static final String KEY_PARTITION = "H"; | ||||
|   private static final String KEY_SORT = "S"; | ||||
|  | @ -60,10 +50,6 @@ public class MessagesDynamoDb { | |||
|   private static final String KEY_CONTENT = "C"; | ||||
|   private static final String KEY_TTL = "E"; | ||||
| 
 | ||||
|   private final Logger logger = LoggerFactory.getLogger(getClass()); | ||||
|   private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); | ||||
|   private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); | ||||
|   private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); | ||||
|   private final Timer storeTimer = timer(name(getClass(), "store")); | ||||
|   private final Timer loadTimer = timer(name(getClass(), "load")); | ||||
|   private final Timer deleteBySourceAndTimestamp = timer(name(getClass(), "delete", "sourceAndTimestamp")); | ||||
|  | @ -71,18 +57,18 @@ public class MessagesDynamoDb { | |||
|   private final Timer deleteByAccount = timer(name(getClass(), "delete", "account")); | ||||
|   private final Timer deleteByDevice = timer(name(getClass(), "delete", "device")); | ||||
| 
 | ||||
|   private final DynamoDB dynamoDb; | ||||
|   private final String tableName; | ||||
|   private final Duration timeToLive; | ||||
| 
 | ||||
|   public MessagesDynamoDb(DynamoDB dynamoDb, String tableName, Duration timeToLive) { | ||||
|     this.dynamoDb = dynamoDb; | ||||
|     super(dynamoDb); | ||||
| 
 | ||||
|     this.tableName = tableName; | ||||
|     this.timeToLive = timeToLive; | ||||
|   } | ||||
| 
 | ||||
|   public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { | ||||
|     storeTimer.record(() -> doInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId), DYNAMO_DB_MAX_BATCH_SIZE)); | ||||
|     storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId))); | ||||
|   } | ||||
| 
 | ||||
|   private void storeBatch(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { | ||||
|  | @ -135,7 +121,7 @@ public class MessagesDynamoDb { | |||
|                                                  .withValueMap(Map.of(":part", partitionKey, | ||||
|                                                                       ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) | ||||
|                                                  .withMaxResultSize(numberOfMessagesToFetch); | ||||
|       final Table table = dynamoDb.getTable(tableName); | ||||
|       final Table table = getDynamoDb().getTable(tableName); | ||||
|       List<OutgoingMessageEntity> messageEntities = new ArrayList<>(numberOfMessagesToFetch); | ||||
|       for (Item message : table.query(querySpec)) { | ||||
|         messageEntities.add(convertItemToOutgoingMessageEntity(message)); | ||||
|  | @ -164,7 +150,7 @@ public class MessagesDynamoDb { | |||
|                                                                       ":source", source, | ||||
|                                                                       ":timestamp", timestamp)); | ||||
| 
 | ||||
|       final Table table = dynamoDb.getTable(tableName); | ||||
|       final Table table = getDynamoDb().getTable(tableName); | ||||
|       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, table); | ||||
|     }); | ||||
|   } | ||||
|  | @ -179,7 +165,7 @@ public class MessagesDynamoDb { | |||
|                                                                      "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) | ||||
|                                                  .withValueMap(Map.of(":part", partitionKey, | ||||
|                                                                       ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))); | ||||
|       final Table table = dynamoDb.getTable(tableName); | ||||
|       final Table table = getDynamoDb().getTable(tableName); | ||||
|       final Index index = table.getIndex(LOCAL_INDEX_MESSAGE_UUID_NAME); | ||||
|       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, index); | ||||
|     }); | ||||
|  | @ -241,62 +227,24 @@ public class MessagesDynamoDb { | |||
|   } | ||||
| 
 | ||||
|   private void deleteRowsMatchingQuery(byte[] partitionKey, QuerySpec querySpec) { | ||||
|     final Table table = dynamoDb.getTable(tableName); | ||||
|     doInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch), DYNAMO_DB_MAX_BATCH_SIZE); | ||||
|     final Table table = getDynamoDb().getTable(tableName); | ||||
|     writeInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch)); | ||||
|   } | ||||
| 
 | ||||
|   private void deleteItems(byte[] partitionKey, List<Item> items) { | ||||
|     final TableWriteItems tableWriteItems = new TableWriteItems(tableName); | ||||
|     items.stream().map((x) -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, x.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); | ||||
|     items.stream().map(item -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, item.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); | ||||
|     executeTableWriteItemsUntilComplete(tableWriteItems); | ||||
|   } | ||||
| 
 | ||||
|   private void executeTableWriteItemsUntilComplete(TableWriteItems items) { | ||||
|     AtomicReference<BatchWriteItemOutcome> outcome = new AtomicReference<>(); | ||||
|     batchWriteItemsFirstPass.record(() -> { | ||||
|       outcome.set(dynamoDb.batchWriteItem(items)); | ||||
|     }); | ||||
|     int attemptCount = 0; | ||||
|     while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { | ||||
|       batchWriteItemsRetryPass.record(() -> { | ||||
|         outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems())); | ||||
|       }); | ||||
|       ++attemptCount; | ||||
|     } | ||||
|     if (!outcome.get().getUnprocessedItems().isEmpty()) { | ||||
|       logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); | ||||
|       batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private long getTtlForMessage(MessageProtos.Envelope message) { | ||||
|     return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); | ||||
|   } | ||||
| 
 | ||||
|   private static <T> void doInBatches(final Iterable<T> items, final Consumer<List<T>> action, final int batchSize) { | ||||
|     List<T> batch = new ArrayList<>(batchSize); | ||||
| 
 | ||||
|     for (T item : items) { | ||||
|       batch.add(item); | ||||
| 
 | ||||
|       if (batch.size() == batchSize) { | ||||
|         action.accept(batch); | ||||
|         batch.clear(); | ||||
|       } | ||||
|     } | ||||
|     if (!batch.isEmpty()) { | ||||
|       action.accept(batch); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private static byte[] convertPartitionKey(final UUID destinationAccountUuid) { | ||||
|     return convertUuidToBytes(destinationAccountUuid); | ||||
|   } | ||||
| 
 | ||||
|   private static UUID convertPartitionKey(final byte[] bytes) { | ||||
|     return convertUuidFromBytes(bytes, "partition key"); | ||||
|   } | ||||
| 
 | ||||
|   private static byte[] convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { | ||||
|     ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); | ||||
|     byteBuffer.putLong(destinationDeviceId); | ||||
|  |  | |||
|  | @ -0,0 +1,23 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public interface PreKeyStore { | ||||
| 
 | ||||
|     void store(Account account, long deviceId, List<PreKey> keys); | ||||
| 
 | ||||
|     int getCount(Account account, long deviceId); | ||||
| 
 | ||||
|     List<KeyRecord> take(Account account, long deviceId); | ||||
| 
 | ||||
|     List<KeyRecord> take(Account account); | ||||
| 
 | ||||
|     void delete(Account account); | ||||
| } | ||||
|  | @ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager; | |||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||
| import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | ||||
| import org.whispersystems.textsecuregcm.storage.Keys; | ||||
| import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||
| import org.whispersystems.textsecuregcm.storage.Messages; | ||||
| import org.whispersystems.textsecuregcm.storage.MessagesCache; | ||||
| import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; | ||||
|  | @ -97,7 +98,16 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | |||
|               .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||
|                                                                 .withRequestTimeout((int) configuration.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||
|               .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||
| 
 | ||||
|       AmazonDynamoDBClientBuilder keysDynamoDbClientBuilder = AmazonDynamoDBClientBuilder | ||||
|               .standard() | ||||
|               .withRegion(configuration.getKeysDynamoDbConfiguration().getRegion()) | ||||
|               .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getKeysDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||
|                                                                 .withRequestTimeout((int) configuration.getKeysDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||
|               .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||
| 
 | ||||
|       DynamoDB messageDynamoDb = new DynamoDB(clientBuilder.build()); | ||||
|       DynamoDB preKeyDynamoDb  = new DynamoDB(keysDynamoDbClientBuilder.build()); | ||||
| 
 | ||||
|       FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", configuration.getCacheClusterConfiguration(), redisClusterClientResources); | ||||
| 
 | ||||
|  | @ -111,6 +121,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | |||
|       Profiles                  profiles             = new Profiles(accountDatabase); | ||||
|       ReservedUsernames         reservedUsernames    = new ReservedUsernames(accountDatabase); | ||||
|       Keys                      keys                 = new Keys(accountDatabase, configuration.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); | ||||
|       KeysDynamoDb              keysDynamoDb         = new KeysDynamoDb(messageDynamoDb, configuration.getKeysDynamoDbConfiguration().getTableName()); | ||||
|       Messages                  messages             = new Messages(messageDatabase); | ||||
|       MessagesDynamoDb          messagesDynamoDb     = new MessagesDynamoDb(messageDynamoDb, configuration.getMessageDynamoDbConfiguration().getTableName(), configuration.getMessageDynamoDbConfiguration().getTimeToLive()); | ||||
|       ReplicatedJedisPool       redisClient          = new RedisClientFactory("directory_cache_delete_command", configuration.getDirectoryConfiguration().getRedisConfiguration().getUrl(), configuration.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls(), configuration.getDirectoryConfiguration().getRedisConfiguration().getCircuitBreakerConfiguration()).getRedisClientPool(); | ||||
|  | @ -124,7 +135,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | |||
|       UsernamesManager          usernamesManager     = new UsernamesManager(usernames, reservedUsernames, cacheCluster); | ||||
|       ProfilesManager           profilesManager      = new ProfilesManager(profiles, cacheCluster); | ||||
|       MessagesManager           messagesManager      = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, new ExperimentEnrollmentManager(dynamicConfigurationManager)); | ||||
|       AccountsManager           accountsManager      = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|       AccountsManager           accountsManager      = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
| 
 | ||||
|       for (String user: users) { | ||||
|         Optional<Account> account = accountsManager.get(user); | ||||
|  |  | |||
|  | @ -0,0 +1,40 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||
| import com.amazonaws.services.dynamodbv2.model.AttributeDefinition; | ||||
| import com.amazonaws.services.dynamodbv2.model.CreateTableRequest; | ||||
| import com.amazonaws.services.dynamodbv2.model.KeySchemaElement; | ||||
| import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput; | ||||
| import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType; | ||||
| import org.whispersystems.textsecuregcm.tests.util.LocalDynamoDbRule; | ||||
| 
 | ||||
| public class KeysDynamoDbRule extends LocalDynamoDbRule { | ||||
|     public static final String TABLE_NAME = "Signal_Keys_Test"; | ||||
| 
 | ||||
|     @Override | ||||
|     protected void before() throws Throwable { | ||||
|         super.before(); | ||||
| 
 | ||||
|         final DynamoDB dynamoDB = getDynamoDB(); | ||||
| 
 | ||||
|         final CreateTableRequest createTableRequest = new CreateTableRequest() | ||||
|                 .withTableName(TABLE_NAME) | ||||
|                 .withKeySchema(new KeySchemaElement(KeysDynamoDb.KEY_ACCOUNT_UUID, "HASH"), | ||||
|                                new KeySchemaElement(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, "RANGE")) | ||||
|                 .withAttributeDefinitions(new AttributeDefinition(KeysDynamoDb.KEY_ACCOUNT_UUID, ScalarAttributeType.B), | ||||
|                                           new AttributeDefinition(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, ScalarAttributeType.B)) | ||||
|                 .withProvisionedThroughput(new ProvisionedThroughput(20L, 20L)); | ||||
| 
 | ||||
|         dynamoDB.createTable(createTableRequest); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected void after() { | ||||
|         super.after(); | ||||
|     } | ||||
| } | ||||
|  | @ -0,0 +1,136 @@ | |||
| /* | ||||
|  * Copyright 2021 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.storage; | ||||
| 
 | ||||
| import org.junit.Before; | ||||
| import org.junit.ClassRule; | ||||
| import org.junit.Test; | ||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | ||||
| 
 | ||||
| import java.util.Collections; | ||||
| import java.util.HashSet; | ||||
| import java.util.List; | ||||
| import java.util.Set; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.mockito.Mockito.mock; | ||||
| import static org.mockito.Mockito.when; | ||||
| 
 | ||||
| public class KeysDynamoDbTest { | ||||
| 
 | ||||
|     private Account account; | ||||
|     private KeysDynamoDb keysDynamoDb; | ||||
| 
 | ||||
|     @ClassRule | ||||
|     public static KeysDynamoDbRule dynamoDbRule = new KeysDynamoDbRule(); | ||||
| 
 | ||||
|     private static final String ACCOUNT_NUMBER = "+18005551234"; | ||||
|     private static final long DEVICE_ID = 1L; | ||||
| 
 | ||||
|     @Before | ||||
|     public void setup() { | ||||
|         keysDynamoDb = new KeysDynamoDb(dynamoDbRule.getDynamoDB(), KeysDynamoDbRule.TABLE_NAME); | ||||
| 
 | ||||
|         account = mock(Account.class); | ||||
|         when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); | ||||
|         when(account.getUuid()).thenReturn(UUID.randomUUID()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testStore() { | ||||
|         assertEquals("Initial pre-key count for an account should be zero", | ||||
|                 0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||
|         assertEquals("Repeatedly storing same key should have no effect", | ||||
|                 1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); | ||||
|         assertEquals("Inserting a new key should overwrite all prior keys for the given account/device", | ||||
|                 1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); | ||||
|         assertEquals("Inserting multiple new keys should overwrite all prior keys for the given account/device", | ||||
|                 2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTakeAccount() { | ||||
|         final Device firstDevice = mock(Device.class); | ||||
|         final Device secondDevice = mock(Device.class); | ||||
| 
 | ||||
|         when(firstDevice.getId()).thenReturn(DEVICE_ID); | ||||
|         when(secondDevice.getId()).thenReturn(DEVICE_ID + 1); | ||||
|         when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice)); | ||||
| 
 | ||||
|         assertEquals(Collections.emptyList(), keysDynamoDb.take(account)); | ||||
| 
 | ||||
|         final PreKey firstDevicePreKey = new PreKey(1, "public-key"); | ||||
|         final PreKey secondDevicePreKey = new PreKey(2, "second-key"); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(firstDevicePreKey)); | ||||
|         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey)); | ||||
| 
 | ||||
|         final Set<KeyRecord> expectedKeys = Set.of( | ||||
|                 new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, firstDevicePreKey.getKeyId(), firstDevicePreKey.getPublicKey()), | ||||
|                 new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID + 1, secondDevicePreKey.getKeyId(), secondDevicePreKey.getPublicKey())); | ||||
| 
 | ||||
|         assertEquals(expectedKeys, new HashSet<>(keysDynamoDb.take(account))); | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTakeAccountAndDeviceId() { | ||||
|         assertEquals(Collections.emptyList(), keysDynamoDb.take(account, DEVICE_ID)); | ||||
| 
 | ||||
|         final PreKey preKey = new PreKey(1, "public-key"); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); | ||||
|         assertEquals(List.of(new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, preKey.getKeyId(), preKey.getPublicKey())), keysDynamoDb.take(account, DEVICE_ID)); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGetCount() { | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
| 
 | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDeleteByAccount() { | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); | ||||
|         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); | ||||
| 
 | ||||
|         assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||
| 
 | ||||
|         keysDynamoDb.delete(account); | ||||
| 
 | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDeleteByAccountAndDevice() { | ||||
|         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); | ||||
|         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); | ||||
| 
 | ||||
|         assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||
| 
 | ||||
|         keysDynamoDb.delete(account, DEVICE_ID); | ||||
| 
 | ||||
|         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||
|         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||
|     } | ||||
| } | ||||
|  | @ -46,7 +46,7 @@ import io.dropwizard.testing.junit.ResourceTestRule; | |||
| import static org.assertj.core.api.Assertions.assertThat; | ||||
| import static org.mockito.Mockito.*; | ||||
| 
 | ||||
| public class KeyControllerTest { | ||||
| public class KeysControllerTest { | ||||
| 
 | ||||
|   private static final String EXISTS_NUMBER = "+14152222222"; | ||||
|   private static final UUID   EXISTS_UUID   = UUID.randomUUID(); | ||||
|  | @ -141,18 +141,16 @@ public class KeyControllerTest { | |||
| 
 | ||||
|     List<KeyRecord> singleDevice = new LinkedList<>(); | ||||
|     singleDevice.add(SAMPLE_KEY); | ||||
|     when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice); | ||||
| 
 | ||||
|     when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>()); | ||||
|     when(keys.take(eq(existsAccount), eq(1L))).thenReturn(singleDevice); | ||||
| 
 | ||||
|     List<KeyRecord> multiDevice = new LinkedList<>(); | ||||
|     multiDevice.add(SAMPLE_KEY); | ||||
|     multiDevice.add(SAMPLE_KEY2); | ||||
|     multiDevice.add(SAMPLE_KEY3); | ||||
|     multiDevice.add(SAMPLE_KEY4); | ||||
|     when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice); | ||||
|     when(keys.take(existsAccount)).thenReturn(multiDevice); | ||||
| 
 | ||||
|     when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); | ||||
|     when(keys.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5); | ||||
| 
 | ||||
|     when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); | ||||
|     when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); | ||||
|  | @ -169,7 +167,7 @@ public class KeyControllerTest { | |||
| 
 | ||||
|     assertThat(result.getCount()).isEqualTo(4); | ||||
| 
 | ||||
|     verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); | ||||
|     verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|  | @ -183,7 +181,7 @@ public class KeyControllerTest { | |||
| 
 | ||||
|     assertThat(result.getCount()).isEqualTo(4); | ||||
| 
 | ||||
|     verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); | ||||
|     verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); | ||||
|   } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -283,7 +281,7 @@ public class KeyControllerTest { | |||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); | ||||
|     verify(keys).take(eq(existsAccount), eq(1L)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -301,7 +299,7 @@ public class KeyControllerTest { | |||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); | ||||
|     verify(keys).take(eq(existsAccount), eq(1L)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -320,7 +318,7 @@ public class KeyControllerTest { | |||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); | ||||
|     verify(keys).take(eq(existsAccount), eq(1L)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -338,7 +336,7 @@ public class KeyControllerTest { | |||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); | ||||
|     verify(keys).take(eq(existsAccount), eq(1L)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -414,7 +412,7 @@ public class KeyControllerTest { | |||
|     assertThat(signedPreKey).isNull(); | ||||
|     assertThat(deviceId).isEqualTo(4); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER)); | ||||
|     verify(keys).take(eq(existsAccount)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -464,7 +462,7 @@ public class KeyControllerTest { | |||
|     assertThat(signedPreKey).isNull(); | ||||
|     assertThat(deviceId).isEqualTo(4); | ||||
| 
 | ||||
|     verify(keys).get(eq(EXISTS_NUMBER)); | ||||
|     verify(keys).take(eq(existsAccount)); | ||||
|     verifyNoMoreInteractions(keys); | ||||
|   } | ||||
| 
 | ||||
|  | @ -533,7 +531,7 @@ public class KeyControllerTest { | |||
|     assertThat(response.getStatus()).isEqualTo(204); | ||||
| 
 | ||||
|     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); | ||||
|     verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture()); | ||||
|     verify(keys).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); | ||||
| 
 | ||||
|     List<PreKey> capturedList = listCaptor.getValue(); | ||||
|     assertThat(capturedList.size()).isEqualTo(1); | ||||
|  | @ -567,7 +565,7 @@ public class KeyControllerTest { | |||
|     assertThat(response.getStatus()).isEqualTo(204); | ||||
| 
 | ||||
|     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); | ||||
|     verify(keys).store(eq(AuthHelper.DISABLED_NUMBER), eq(1L), listCaptor.capture()); | ||||
|     verify(keys).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); | ||||
| 
 | ||||
|     List<PreKey> capturedList = listCaptor.getValue(); | ||||
|     assertThat(capturedList.size()).isEqualTo(1); | ||||
|  | @ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.storage; | |||
| import io.lettuce.core.RedisException; | ||||
| import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; | ||||
| import org.junit.Test; | ||||
| import org.whispersystems.textsecuregcm.entities.Profile; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; | ||||
| import org.whispersystems.textsecuregcm.storage.Account; | ||||
|  | @ -16,6 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts; | |||
| import org.whispersystems.textsecuregcm.storage.AccountsManager; | ||||
| import org.whispersystems.textsecuregcm.storage.DirectoryManager; | ||||
| import org.whispersystems.textsecuregcm.storage.Keys; | ||||
| import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||
| import org.whispersystems.textsecuregcm.storage.MessagesManager; | ||||
| import org.whispersystems.textsecuregcm.storage.ProfilesManager; | ||||
| import org.whispersystems.textsecuregcm.storage.UsernamesManager; | ||||
|  | @ -46,6 +46,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -55,7 +56,7 @@ public class AccountsManagerTest { | |||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString()); | ||||
|     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> account         = accountsManager.get("+14152222222"); | ||||
| 
 | ||||
|     assertTrue(account.isPresent()); | ||||
|  | @ -76,6 +77,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -84,7 +86,7 @@ public class AccountsManagerTest { | |||
| 
 | ||||
|     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> account         = accountsManager.get(uuid); | ||||
| 
 | ||||
|     assertTrue(account.isPresent()); | ||||
|  | @ -106,6 +108,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -115,7 +118,7 @@ public class AccountsManagerTest { | |||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null); | ||||
|     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> retrieved       = accountsManager.get("+14152222222"); | ||||
| 
 | ||||
|     assertTrue(retrieved.isPresent()); | ||||
|  | @ -138,6 +141,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -147,7 +151,7 @@ public class AccountsManagerTest { | |||
|     when(commands.get(eq("Account3::" + uuid))).thenReturn(null); | ||||
|     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> retrieved       = accountsManager.get(uuid); | ||||
| 
 | ||||
|     assertTrue(retrieved.isPresent()); | ||||
|  | @ -170,6 +174,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -179,7 +184,7 @@ public class AccountsManagerTest { | |||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!")); | ||||
|     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> retrieved       = accountsManager.get("+14152222222"); | ||||
| 
 | ||||
|     assertTrue(retrieved.isPresent()); | ||||
|  | @ -202,6 +207,7 @@ public class AccountsManagerTest { | |||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||
|     Keys                                         keys             = mock(Keys.class); | ||||
|     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||
|  | @ -211,7 +217,7 @@ public class AccountsManagerTest { | |||
|     when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); | ||||
|     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); | ||||
| 
 | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); | ||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||
|     Optional<Account> retrieved       = accountsManager.get(uuid); | ||||
| 
 | ||||
|     assertTrue(retrieved.isPresent()); | ||||
|  |  | |||
|  | @ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio | |||
| import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | ||||
| import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; | ||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | ||||
| import org.whispersystems.textsecuregcm.storage.Account; | ||||
| import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | ||||
| import org.whispersystems.textsecuregcm.storage.KeyRecord; | ||||
| import org.whispersystems.textsecuregcm.storage.Keys; | ||||
|  | @ -41,13 +42,14 @@ import static org.mockito.Mockito.doThrow; | |||
| import static org.mockito.Mockito.mock; | ||||
| import static org.mockito.Mockito.when; | ||||
| 
 | ||||
| @Ignore | ||||
| public class KeysTest { | ||||
| 
 | ||||
|   @Rule | ||||
|   public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); | ||||
| 
 | ||||
|   private Keys keys; | ||||
|   private Account firstAccount; | ||||
|   private Account secondAccount; | ||||
|   private Keys    keys; | ||||
| 
 | ||||
|   @Before | ||||
|   public void setup() { | ||||
|  | @ -56,6 +58,12 @@ public class KeysTest { | |||
|                                                                             new CircuitBreakerConfiguration()); | ||||
| 
 | ||||
|     this.keys = new Keys(faultTolerantDatabase, new RetryConfiguration()); | ||||
| 
 | ||||
|     this.firstAccount  = mock(Account.class); | ||||
|     this.secondAccount = mock(Account.class); | ||||
| 
 | ||||
|     when(firstAccount.getNumber()).thenReturn("+14152222222"); | ||||
|     when(secondAccount.getNumber()).thenReturn("+14151111111"); | ||||
|   } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -79,18 +87,18 @@ public class KeysTest { | |||
|       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); | ||||
|     } | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||
| 
 | ||||
|     keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys); | ||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); | ||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store(secondAccount, 1, oldAnotherDeviceOnePrKeys); | ||||
|     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||
|     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||
| 
 | ||||
|     PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id"); | ||||
|     verifyStoredState(statement, "+14152222222", 1); | ||||
|     verifyStoredState(statement, "+14152222222", 2); | ||||
|     verifyStoredState(statement, "+14151111111", 1); | ||||
|     verifyStoredState(statement, "+14151111111", 2); | ||||
|     verifyStoredState(statement, firstAccount, 1); | ||||
|     verifyStoredState(statement, firstAccount, 2); | ||||
|     verifyStoredState(statement, secondAccount, 1); | ||||
|     verifyStoredState(statement, secondAccount, 2); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|  | @ -102,11 +110,12 @@ public class KeysTest { | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||
|   } | ||||
| 
 | ||||
|   @Ignore | ||||
|   @Test | ||||
|   public void testGetForDevice() { | ||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||
|  | @ -125,45 +134,46 @@ public class KeysTest { | |||
|       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); | ||||
|     } | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||
| 
 | ||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); | ||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||
|     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||
| 
 | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); | ||||
|     List<KeyRecord> records = keys.get("+14152222222", 1); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||
|     List<KeyRecord> records = keys.take(firstAccount, 1); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(1); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1"); | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||
| 
 | ||||
|     records = keys.get("+14152222222", 1); | ||||
|     records = keys.take(firstAccount, 1); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(1); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(2); | ||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2"); | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||
| 
 | ||||
|     records = keys.get("+14152222222", 2); | ||||
|     records = keys.take(firstAccount, 2); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(1); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1"); | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||
|   } | ||||
| 
 | ||||
|   @Ignore | ||||
|   @Test | ||||
|   public void testGetForAllDevices() { | ||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||
|  | @ -184,18 +194,18 @@ public class KeysTest { | |||
|       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); | ||||
|     } | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||
| 
 | ||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); | ||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store("+14151111111", 3, anotherDeviceThreePreKeys); | ||||
|     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||
|     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store(secondAccount, 3, anotherDeviceThreePreKeys); | ||||
| 
 | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||
| 
 | ||||
|     List<KeyRecord> records = keys.get("+14152222222"); | ||||
|     List<KeyRecord> records = keys.take(firstAccount); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(2); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||
|  | @ -204,10 +214,10 @@ public class KeysTest { | |||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue(); | ||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue(); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); | ||||
| 
 | ||||
|     records = keys.get("+14152222222"); | ||||
|     records = keys.take(firstAccount); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(2); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(2); | ||||
|  | @ -216,11 +226,11 @@ public class KeysTest { | |||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue(); | ||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue(); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(98); | ||||
| 
 | ||||
| 
 | ||||
|     records = keys.get("+14151111111"); | ||||
|     records = keys.take(secondAccount); | ||||
| 
 | ||||
|     assertThat(records.size()).isEqualTo(3); | ||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||
|  | @ -231,11 +241,12 @@ public class KeysTest { | |||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue(); | ||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue(); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99); | ||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(99); | ||||
|     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(99); | ||||
|   } | ||||
| 
 | ||||
|   @Ignore | ||||
|   @Test | ||||
|   public void testGetForAllDevicesParallel() throws InterruptedException { | ||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||
|  | @ -246,11 +257,11 @@ public class KeysTest { | |||
|       deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); | ||||
|     } | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||
| 
 | ||||
|     List<Thread> threads = new LinkedList<>(); | ||||
| 
 | ||||
|  | @ -260,7 +271,7 @@ public class KeysTest { | |||
|         final int MAX_RETRIES = 5; | ||||
|         for (int retryAttempt = 0; results == null && retryAttempt < MAX_RETRIES; ++retryAttempt) { | ||||
|           try { | ||||
|             results = keys.get("+14152222222"); | ||||
|             results = keys.take(firstAccount); | ||||
|           } catch (UnableToExecuteStatementException e) { | ||||
|             if (retryAttempt == MAX_RETRIES - 1) { | ||||
|               throw e; | ||||
|  | @ -278,8 +289,8 @@ public class KeysTest { | |||
|       thread.join(); | ||||
|     } | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(80); | ||||
|     assertThat(keys.getCount("+14152222222",2)).isEqualTo(80); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(80); | ||||
|     assertThat(keys.getCount(firstAccount,2)).isEqualTo(80); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|  | @ -302,32 +313,32 @@ public class KeysTest { | |||
|       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); | ||||
|     } | ||||
| 
 | ||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); | ||||
|     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||
| 
 | ||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); | ||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store("+14151111111", 3, anotherDeviceThreePreKeys); | ||||
|     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||
|     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||
|     keys.store(secondAccount, 3, anotherDeviceThreePreKeys); | ||||
| 
 | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); | ||||
| 
 | ||||
|     keys.delete("+14152222222"); | ||||
|     keys.delete(firstAccount); | ||||
| 
 | ||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(0); | ||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(0); | ||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(0); | ||||
|     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(0); | ||||
|     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||
|     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void testEmptyKeyGet() { | ||||
|     List<KeyRecord> records = keys.get("+14152222222"); | ||||
|     List<KeyRecord> records = keys.take(firstAccount); | ||||
| 
 | ||||
|     assertThat(records.isEmpty()).isTrue(); | ||||
|   } | ||||
|  | @ -361,21 +372,21 @@ public class KeysTest { | |||
|     } | ||||
| 
 | ||||
|     try { | ||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|       throw new AssertionError(); | ||||
|     } catch (TransactionException e) { | ||||
|       // good | ||||
|     } | ||||
| 
 | ||||
|     try { | ||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|       throw new AssertionError(); | ||||
|     } catch (TransactionException e) { | ||||
|       // good | ||||
|     } | ||||
| 
 | ||||
|     try { | ||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|       throw new AssertionError(); | ||||
|     } catch (CallNotPermittedException e) { | ||||
|       // good | ||||
|  | @ -384,7 +395,7 @@ public class KeysTest { | |||
|     Thread.sleep(1100); | ||||
| 
 | ||||
|     try { | ||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); | ||||
|       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||
|       throw new AssertionError(); | ||||
|     } catch (TransactionException e) { | ||||
|       // good | ||||
|  | @ -401,7 +412,10 @@ public class KeysTest { | |||
|     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); | ||||
| 
 | ||||
|     // We're happy as long as nothing throws an exception | ||||
|     keys.store("+18005551234", 1, Collections.emptyList()); | ||||
|     Account account = mock(Account.class); | ||||
|     when(account.getNumber()).thenReturn("+18005551234"); | ||||
| 
 | ||||
|     keys.store(account, 1, Collections.emptyList()); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|  | @ -414,12 +428,15 @@ public class KeysTest { | |||
| 
 | ||||
|     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); | ||||
| 
 | ||||
|     assertThat(keys.get("+18005551234")).isEqualTo(Collections.emptyList()); | ||||
|     assertThat(keys.get("+18005551234", 1)).isEqualTo(Collections.emptyList()); | ||||
|     Account account = mock(Account.class); | ||||
|     when(account.getNumber()).thenReturn("+18005551234"); | ||||
| 
 | ||||
|     assertThat(keys.take(account)).isEqualTo(Collections.emptyList()); | ||||
|     assertThat(keys.take(account, 1)).isEqualTo(Collections.emptyList()); | ||||
|   } | ||||
| 
 | ||||
|   private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException { | ||||
|     statement.setString(1, number); | ||||
|   private void verifyStoredState(PreparedStatement statement, Account account, int deviceId) throws SQLException { | ||||
|     statement.setString(1, account.getNumber()); | ||||
|     statement.setInt(2, deviceId); | ||||
| 
 | ||||
|     ResultSet resultSet = statement.executeQuery(); | ||||
|  | @ -431,7 +448,7 @@ public class KeysTest { | |||
| 
 | ||||
| 
 | ||||
|       assertThat(keyId).isEqualTo(rowCount); | ||||
|       assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount); | ||||
|       assertThat(publicKey).isEqualTo(account.getNumber() + "Device" + deviceId + "PublicKey" + rowCount); | ||||
| 
 | ||||
|       rowCount++; | ||||
|     } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Jon Chambers
						Jon Chambers