Add a Dynamo-backed key store.
This commit is contained in:
		
							parent
							
								
									426e6923ac
								
							
						
					
					
						commit
						d4d9403829
					
				|  | @ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguratio | ||||||
| import org.whispersystems.textsecuregcm.configuration.CdnConfiguration; | import org.whispersystems.textsecuregcm.configuration.CdnConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration; | import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration; | import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration; | ||||||
|  | import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; | import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; | import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | ||||||
|  | @ -128,6 +129,11 @@ public class WhisperServerConfiguration extends Configuration { | ||||||
|   @JsonProperty |   @JsonProperty | ||||||
|   private MessageDynamoDbConfiguration messageDynamoDb; |   private MessageDynamoDbConfiguration messageDynamoDb; | ||||||
| 
 | 
 | ||||||
|  |   @Valid | ||||||
|  |   @NotNull | ||||||
|  |   @JsonProperty | ||||||
|  |   private DynamoDbConfiguration keysDynamoDb; | ||||||
|  | 
 | ||||||
|   @Valid |   @Valid | ||||||
|   @NotNull |   @NotNull | ||||||
|   @JsonProperty |   @JsonProperty | ||||||
|  | @ -306,6 +312,10 @@ public class WhisperServerConfiguration extends Configuration { | ||||||
|     return messageDynamoDb; |     return messageDynamoDb; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   public DynamoDbConfiguration getKeysDynamoDbConfiguration() { | ||||||
|  |     return keysDynamoDb; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public DatabaseConfiguration getMessageStoreConfiguration() { |   public DatabaseConfiguration getMessageStoreConfiguration() { | ||||||
|     return messageStore; |     return messageStore; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -126,6 +126,7 @@ import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | ||||||
| import org.whispersystems.textsecuregcm.storage.FeatureFlags; | import org.whispersystems.textsecuregcm.storage.FeatureFlags; | ||||||
| import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; | import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.Keys; | import org.whispersystems.textsecuregcm.storage.Keys; | ||||||
|  | import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||||
| import org.whispersystems.textsecuregcm.storage.MessagePersister; | import org.whispersystems.textsecuregcm.storage.MessagePersister; | ||||||
| import org.whispersystems.textsecuregcm.storage.Messages; | import org.whispersystems.textsecuregcm.storage.Messages; | ||||||
| import org.whispersystems.textsecuregcm.storage.MessagesCache; | import org.whispersystems.textsecuregcm.storage.MessagesCache; | ||||||
|  | @ -275,7 +276,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | ||||||
|             .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) |             .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||||
|                                                               .withRequestTimeout((int) config.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) |                                                               .withRequestTimeout((int) config.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||||
|             .withCredentials(InstanceProfileCredentialsProvider.getInstance()); |             .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||||
|  | 
 | ||||||
|  |     AmazonDynamoDBClientBuilder keysDynamoDbClientBuilder = AmazonDynamoDBClientBuilder | ||||||
|  |             .standard() | ||||||
|  |             .withRegion(config.getKeysDynamoDbConfiguration().getRegion()) | ||||||
|  |             .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getKeysDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||||
|  |                                                               .withRequestTimeout((int) config.getKeysDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||||
|  |             .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||||
|  | 
 | ||||||
|     DynamoDB messageDynamoDb = new DynamoDB(messageDynamoDbClientBuilder.build()); |     DynamoDB messageDynamoDb = new DynamoDB(messageDynamoDbClientBuilder.build()); | ||||||
|  |     DynamoDB preKeyDynamoDb = new DynamoDB(keysDynamoDbClientBuilder.build()); | ||||||
| 
 | 
 | ||||||
|     Accounts          accounts          = new Accounts(accountDatabase); |     Accounts          accounts          = new Accounts(accountDatabase); | ||||||
|     PendingAccounts   pendingAccounts   = new PendingAccounts(accountDatabase); |     PendingAccounts   pendingAccounts   = new PendingAccounts(accountDatabase); | ||||||
|  | @ -284,6 +294,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | ||||||
|     ReservedUsernames reservedUsernames = new ReservedUsernames(accountDatabase); |     ReservedUsernames reservedUsernames = new ReservedUsernames(accountDatabase); | ||||||
|     Profiles          profiles          = new Profiles(accountDatabase); |     Profiles          profiles          = new Profiles(accountDatabase); | ||||||
|     Keys              keys              = new Keys(accountDatabase, config.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); |     Keys              keys              = new Keys(accountDatabase, config.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); | ||||||
|  |     KeysDynamoDb      keysDynamoDb      = new KeysDynamoDb(preKeyDynamoDb, config.getKeysDynamoDbConfiguration().getTableName()); | ||||||
|     Messages          messages          = new Messages(messageDatabase); |     Messages          messages          = new Messages(messageDatabase); | ||||||
|     MessagesDynamoDb  messagesDynamoDb  = new MessagesDynamoDb(messageDynamoDb, config.getMessageDynamoDbConfiguration().getTableName(), config.getMessageDynamoDbConfiguration().getTimeToLive()); |     MessagesDynamoDb  messagesDynamoDb  = new MessagesDynamoDb(messageDynamoDb, config.getMessageDynamoDbConfiguration().getTableName(), config.getMessageDynamoDbConfiguration().getTimeToLive()); | ||||||
|     AbusiveHostRules  abusiveHostRules  = new AbusiveHostRules(abuseDatabase); |     AbusiveHostRules  abusiveHostRules  = new AbusiveHostRules(abuseDatabase); | ||||||
|  | @ -338,7 +349,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration | ||||||
|     MessagesCache              messagesCache              = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor); |     MessagesCache              messagesCache              = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor); | ||||||
|     PushLatencyManager         pushLatencyManager         = new PushLatencyManager(metricsCluster); |     PushLatencyManager         pushLatencyManager         = new PushLatencyManager(metricsCluster); | ||||||
|     MessagesManager            messagesManager            = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, experimentEnrollmentManager); |     MessagesManager            messagesManager            = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, experimentEnrollmentManager); | ||||||
|     AccountsManager            accountsManager            = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager            accountsManager            = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     RemoteConfigsManager       remoteConfigsManager       = new RemoteConfigsManager(remoteConfigs); |     RemoteConfigsManager       remoteConfigsManager       = new RemoteConfigsManager(remoteConfigs); | ||||||
|     FeatureFlagsManager        featureFlagsManager        = new FeatureFlagsManager(featureFlags, recurringJobExecutor); |     FeatureFlagsManager        featureFlagsManager        = new FeatureFlagsManager(featureFlags, recurringJobExecutor); | ||||||
|     DeadLetterHandler          deadLetterHandler          = new DeadLetterHandler(accountsManager, messagesManager); |     DeadLetterHandler          deadLetterHandler          = new DeadLetterHandler(accountsManager, messagesManager); | ||||||
|  |  | ||||||
|  | @ -0,0 +1,50 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.configuration; | ||||||
|  | 
 | ||||||
|  | import com.fasterxml.jackson.annotation.JsonProperty; | ||||||
|  | 
 | ||||||
|  | import javax.validation.Valid; | ||||||
|  | import javax.validation.constraints.NotBlank; | ||||||
|  | import javax.validation.constraints.NotEmpty; | ||||||
|  | import java.time.Duration; | ||||||
|  | 
 | ||||||
|  | public class DynamoDbConfiguration { | ||||||
|  | 
 | ||||||
|  |     @JsonProperty | ||||||
|  |     @NotBlank | ||||||
|  |     private String region; | ||||||
|  | 
 | ||||||
|  |     @JsonProperty | ||||||
|  |     @NotBlank | ||||||
|  |     private String tableName; | ||||||
|  | 
 | ||||||
|  |     @JsonProperty | ||||||
|  |     private Duration clientExecutionTimeout = Duration.ofSeconds(30); | ||||||
|  | 
 | ||||||
|  |     @JsonProperty | ||||||
|  |     private Duration clientRequestTimeout = Duration.ofSeconds(10); | ||||||
|  | 
 | ||||||
|  |     @Valid | ||||||
|  |     @NotEmpty | ||||||
|  |     public String getRegion() { | ||||||
|  |         return region; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Valid | ||||||
|  |     @NotEmpty | ||||||
|  |     public String getTableName() { | ||||||
|  |         return tableName; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public Duration getClientExecutionTimeout() { | ||||||
|  |         return clientExecutionTimeout; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public Duration getClientRequestTimeout() { | ||||||
|  |         return clientRequestTimeout; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -9,35 +9,12 @@ import javax.validation.Valid; | ||||||
| import javax.validation.constraints.NotEmpty; | import javax.validation.constraints.NotEmpty; | ||||||
| import java.time.Duration; | import java.time.Duration; | ||||||
| 
 | 
 | ||||||
| public class MessageDynamoDbConfiguration { | public class MessageDynamoDbConfiguration extends DynamoDbConfiguration { | ||||||
|   private String region; | 
 | ||||||
|   private String tableName; |  | ||||||
|   private Duration timeToLive = Duration.ofDays(7); |   private Duration timeToLive = Duration.ofDays(7); | ||||||
|   private Duration clientExecutionTimeout = Duration.ofSeconds(30); |  | ||||||
|   private Duration clientRequestTimeout = Duration.ofSeconds(10); |  | ||||||
| 
 |  | ||||||
|   @Valid |  | ||||||
|   @NotEmpty |  | ||||||
|   public String getRegion() { |  | ||||||
|     return region; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   @Valid |  | ||||||
|   @NotEmpty |  | ||||||
|   public String getTableName() { |  | ||||||
|     return tableName; |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   @Valid |   @Valid | ||||||
|   public Duration getTimeToLive() { |   public Duration getTimeToLive() { | ||||||
|     return timeToLive; |     return timeToLive; | ||||||
|   } |   } | ||||||
| 
 |  | ||||||
|   public Duration getClientExecutionTimeout() { |  | ||||||
|     return clientExecutionTimeout; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   public Duration getClientRequestTimeout() { |  | ||||||
|     return clientRequestTimeout; |  | ||||||
|   } |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -62,7 +62,7 @@ public class KeysController { | ||||||
|   @GET |   @GET | ||||||
|   @Produces(MediaType.APPLICATION_JSON) |   @Produces(MediaType.APPLICATION_JSON) | ||||||
|   public PreKeyCount getStatus(@Auth Account account) { |   public PreKeyCount getStatus(@Auth Account account) { | ||||||
|     int count = keys.getCount(account.getNumber(), account.getAuthenticatedDevice().get().getId()); |     int count = keys.getCount(account, account.getAuthenticatedDevice().get().getId()); | ||||||
| 
 | 
 | ||||||
|     if (count > 0) { |     if (count > 0) { | ||||||
|       count = count - 1; |       count = count - 1; | ||||||
|  | @ -98,7 +98,7 @@ public class KeysController { | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store(account.getNumber(), device.getId(), preKeys.getPreKeys()); |     keys.store(account, device.getId(), preKeys.getPreKeys()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Timed |   @Timed | ||||||
|  | @ -179,12 +179,12 @@ public class KeysController { | ||||||
|   private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) { |   private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) { | ||||||
|     try { |     try { | ||||||
|       if (deviceIdSelector.equals("*")) { |       if (deviceIdSelector.equals("*")) { | ||||||
|         return keys.get(destination.getNumber()); |         return keys.take(destination); | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       long deviceId = Long.parseLong(deviceIdSelector); |       long deviceId = Long.parseLong(deviceIdSelector); | ||||||
| 
 | 
 | ||||||
|       return keys.get(destination.getNumber(), deviceId); |       return keys.take(destination, deviceId); | ||||||
|     } catch (NumberFormatException e) { |     } catch (NumberFormatException e) { | ||||||
|       throw new WebApplicationException(Response.status(422).build()); |       throw new WebApplicationException(Response.status(422).build()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,76 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.storage; | ||||||
|  | 
 | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.TableWriteItems; | ||||||
|  | import io.micrometer.core.instrument.Counter; | ||||||
|  | import io.micrometer.core.instrument.Timer; | ||||||
|  | import org.slf4j.Logger; | ||||||
|  | import org.slf4j.LoggerFactory; | ||||||
|  | 
 | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.List; | ||||||
|  | import java.util.concurrent.atomic.AtomicReference; | ||||||
|  | import java.util.function.Consumer; | ||||||
|  | 
 | ||||||
|  | import static com.codahale.metrics.MetricRegistry.name; | ||||||
|  | import static io.micrometer.core.instrument.Metrics.counter; | ||||||
|  | import static io.micrometer.core.instrument.Metrics.timer; | ||||||
|  | 
 | ||||||
|  | public class AbstractDynamoDbStore { | ||||||
|  | 
 | ||||||
|  |     private final DynamoDB dynamoDb; | ||||||
|  | 
 | ||||||
|  |     private final Timer   batchWriteItemsFirstPass   = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); | ||||||
|  |     private final Timer   batchWriteItemsRetryPass   = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); | ||||||
|  |     private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); | ||||||
|  | 
 | ||||||
|  |     private final Logger logger = LoggerFactory.getLogger(getClass()); | ||||||
|  | 
 | ||||||
|  |     private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25;  // This was arbitrarily chosen and may be entirely too high. | ||||||
|  |     public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25;  // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. | ||||||
|  |     public static final int RESULT_SET_CHUNK_SIZE = 100; | ||||||
|  | 
 | ||||||
|  |     public AbstractDynamoDbStore(final DynamoDB dynamoDb) { | ||||||
|  |         this.dynamoDb = dynamoDb; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     protected DynamoDB getDynamoDb() { | ||||||
|  |         return dynamoDb; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     protected void executeTableWriteItemsUntilComplete(final TableWriteItems items) { | ||||||
|  |         AtomicReference<BatchWriteItemOutcome> outcome = new AtomicReference<>(); | ||||||
|  |         batchWriteItemsFirstPass.record(() -> outcome.set(dynamoDb.batchWriteItem(items))); | ||||||
|  |         int attemptCount = 0; | ||||||
|  |         while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { | ||||||
|  |             batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems()))); | ||||||
|  |             ++attemptCount; | ||||||
|  |         } | ||||||
|  |         if (!outcome.get().getUnprocessedItems().isEmpty()) { | ||||||
|  |             logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); | ||||||
|  |             batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) { | ||||||
|  |         final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE); | ||||||
|  | 
 | ||||||
|  |         for (T item : items) { | ||||||
|  |             batch.add(item); | ||||||
|  | 
 | ||||||
|  |             if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) { | ||||||
|  |                 action.accept(batch); | ||||||
|  |                 batch.clear(); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         if (!batch.isEmpty()) { | ||||||
|  |             action.accept(batch); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -56,6 +56,7 @@ public class AccountsManager { | ||||||
|   private final DirectoryManager          directory; |   private final DirectoryManager          directory; | ||||||
|   private final DirectoryQueue            directoryQueue; |   private final DirectoryQueue            directoryQueue; | ||||||
|   private final Keys                      keys; |   private final Keys                      keys; | ||||||
|  |   private final KeysDynamoDb              keysDynamoDb; | ||||||
|   private final MessagesManager           messagesManager; |   private final MessagesManager           messagesManager; | ||||||
|   private final UsernamesManager          usernamesManager; |   private final UsernamesManager          usernamesManager; | ||||||
|   private final ProfilesManager           profilesManager; |   private final ProfilesManager           profilesManager; | ||||||
|  | @ -73,12 +74,13 @@ public class AccountsManager { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { |   public AccountsManager(Accounts accounts, DirectoryManager directory, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue, final Keys keys, final KeysDynamoDb keysDynamoDb, final MessagesManager messagesManager, final UsernamesManager usernamesManager, final ProfilesManager profilesManager) { | ||||||
|     this.accounts         = accounts; |     this.accounts         = accounts; | ||||||
|     this.directory        = directory; |     this.directory        = directory; | ||||||
|     this.cacheCluster     = cacheCluster; |     this.cacheCluster     = cacheCluster; | ||||||
|     this.directoryQueue   = directoryQueue; |     this.directoryQueue   = directoryQueue; | ||||||
|     this.keys             = keys; |     this.keys             = keys; | ||||||
|  |     this.keysDynamoDb     = keysDynamoDb; | ||||||
|     this.messagesManager  = messagesManager; |     this.messagesManager  = messagesManager; | ||||||
|     this.usernamesManager = usernamesManager; |     this.usernamesManager = usernamesManager; | ||||||
|     this.profilesManager  = profilesManager; |     this.profilesManager  = profilesManager; | ||||||
|  | @ -150,7 +152,8 @@ public class AccountsManager { | ||||||
|       directoryQueue.deleteAccount(account); |       directoryQueue.deleteAccount(account); | ||||||
|       directory.remove(account.getNumber()); |       directory.remove(account.getNumber()); | ||||||
|       profilesManager.deleteAll(account.getUuid()); |       profilesManager.deleteAll(account.getUuid()); | ||||||
|       keys.delete(account.getNumber()); |       keys.delete(account); | ||||||
|  |       keysDynamoDb.delete(account); | ||||||
|       messagesManager.clear(account.getNumber(), account.getUuid()); |       messagesManager.clear(account.getNumber(), account.getUuid()); | ||||||
|       redisDelete(account); |       redisDelete(account); | ||||||
|       databaseDelete(account); |       databaseDelete(account); | ||||||
|  |  | ||||||
|  | @ -5,6 +5,8 @@ | ||||||
| 
 | 
 | ||||||
| package org.whispersystems.textsecuregcm.storage; | package org.whispersystems.textsecuregcm.storage; | ||||||
| 
 | 
 | ||||||
|  | import java.util.Objects; | ||||||
|  | 
 | ||||||
| public class KeyRecord { | public class KeyRecord { | ||||||
| 
 | 
 | ||||||
|   private long    id; |   private long    id; | ||||||
|  | @ -41,4 +43,20 @@ public class KeyRecord { | ||||||
|     return publicKey; |     return publicKey; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   @Override | ||||||
|  |   public boolean equals(final Object o) { | ||||||
|  |     if (this == o) return true; | ||||||
|  |     if (o == null || getClass() != o.getClass()) return false; | ||||||
|  |     final KeyRecord keyRecord = (KeyRecord)o; | ||||||
|  |     return id == keyRecord.id && | ||||||
|  |             deviceId == keyRecord.deviceId && | ||||||
|  |             keyId == keyRecord.keyId && | ||||||
|  |             Objects.equals(number, keyRecord.number) && | ||||||
|  |             Objects.equals(publicKey, keyRecord.publicKey); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public int hashCode() { | ||||||
|  |     return Objects.hash(id, number, deviceId, keyId, publicKey); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -27,7 +27,7 @@ import java.util.function.Supplier; | ||||||
| 
 | 
 | ||||||
| import static com.codahale.metrics.MetricRegistry.name; | import static com.codahale.metrics.MetricRegistry.name; | ||||||
| 
 | 
 | ||||||
| public class Keys { | public class Keys implements PreKeyStore { | ||||||
| 
 | 
 | ||||||
|   private final MetricRegistry metricRegistry  = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); |   private final MetricRegistry metricRegistry  = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); | ||||||
|   private final Meter          fallbackMeter   = metricRegistry.meter(name(Keys.class, "fallback")); |   private final Meter          fallbackMeter   = metricRegistry.meter(name(Keys.class, "fallback")); | ||||||
|  | @ -49,7 +49,10 @@ public class Keys { | ||||||
|     this.retry = Retry.of("keys", retryConfiguration.toRetryConfigBuilder().build()); |     this.retry = Retry.of("keys", retryConfiguration.toRetryConfigBuilder().build()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public void store(String number, long deviceId, List<PreKey> keys) { |   @Override | ||||||
|  |   public void store(Account account, long deviceId, List<PreKey> keys) { | ||||||
|  |     final String number = account.getNumber(); | ||||||
|  | 
 | ||||||
|     retry.executeRunnable(() -> { |     retry.executeRunnable(() -> { | ||||||
|       database.use(jdbi -> jdbi.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { |       database.use(jdbi -> jdbi.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||||
|         try (Timer.Context ignored = storeTimer.time()) { |         try (Timer.Context ignored = storeTimer.time()) { | ||||||
|  | @ -74,8 +77,12 @@ public class Keys { | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public List<KeyRecord> get(String number, long deviceId) { |   @Override | ||||||
|     /* try { |   public List<KeyRecord> take(Account account, long deviceId) { | ||||||
|  |     /* | ||||||
|  |     final String number = account.getNumber(); | ||||||
|  | 
 | ||||||
|  |     try { | ||||||
|       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { |       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||||
|         try (Timer.Context ignored = getDevicetTimer.time()) { |         try (Timer.Context ignored = getDevicetTimer.time()) { | ||||||
|           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") |           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") | ||||||
|  | @ -95,8 +102,12 @@ public class Keys { | ||||||
|     return new LinkedList<>(); |     return new LinkedList<>(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public List<KeyRecord> get(String number) { |   @Override | ||||||
|     /* try { |   public List<KeyRecord> take(Account account) { | ||||||
|  |     /* | ||||||
|  |     final String number = account.getNumber(); | ||||||
|  | 
 | ||||||
|  |     try { | ||||||
|       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { |       return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { | ||||||
|         try (Timer.Context ignored = getTimer.time()) { |         try (Timer.Context ignored = getTimer.time()) { | ||||||
|           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") |           return handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") | ||||||
|  | @ -115,7 +126,10 @@ public class Keys { | ||||||
|     return new LinkedList<>(); |     return new LinkedList<>(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public int getCount(String number, long deviceId) { |   @Override | ||||||
|  |   public int getCount(Account account, long deviceId) { | ||||||
|  |     final String number = account.getNumber(); | ||||||
|  | 
 | ||||||
|     return database.with(jdbi -> jdbi.withHandle(handle -> { |     return database.with(jdbi -> jdbi.withHandle(handle -> { | ||||||
|       try (Timer.Context ignored = getCountTimer.time()) { |       try (Timer.Context ignored = getCountTimer.time()) { | ||||||
|         return handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") |         return handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") | ||||||
|  | @ -127,7 +141,9 @@ public class Keys { | ||||||
|     })); |     })); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public void delete(final String number) { |   public void delete(final Account account) { | ||||||
|  |     final String number = account.getNumber(); | ||||||
|  | 
 | ||||||
|     database.use(jdbi -> jdbi.useHandle(handle -> { |     database.use(jdbi -> jdbi.useHandle(handle -> { | ||||||
|       try (Timer.Context ignored = getCountTimer.time()) { |       try (Timer.Context ignored = getCountTimer.time()) { | ||||||
|         handle.createUpdate("DELETE FROM keys WHERE number = :number") |         handle.createUpdate("DELETE FROM keys WHERE number = :number") | ||||||
|  |  | ||||||
|  | @ -0,0 +1,208 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.storage; | ||||||
|  | 
 | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.Item; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.PrimaryKey; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.Table; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.TableWriteItems; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.ReturnValue; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.Select; | ||||||
|  | import com.google.common.annotations.VisibleForTesting; | ||||||
|  | import io.micrometer.core.instrument.DistributionSummary; | ||||||
|  | import io.micrometer.core.instrument.Metrics; | ||||||
|  | import io.micrometer.core.instrument.Timer; | ||||||
|  | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
|  | 
 | ||||||
|  | import java.nio.ByteBuffer; | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.List; | ||||||
|  | import java.util.Map; | ||||||
|  | import java.util.UUID; | ||||||
|  | 
 | ||||||
|  | import static com.codahale.metrics.MetricRegistry.name; | ||||||
|  | 
 | ||||||
|  | public class KeysDynamoDb extends AbstractDynamoDbStore implements PreKeyStore { | ||||||
|  | 
 | ||||||
|  |     private final Table table; | ||||||
|  | 
 | ||||||
|  |     static final String KEY_ACCOUNT_UUID = "U"; | ||||||
|  |     static final String KEY_DEVICE_ID_KEY_ID = "DK"; | ||||||
|  |     static final String KEY_PUBLIC_KEY = "P"; | ||||||
|  | 
 | ||||||
|  |     private static final Timer               STORE_KEYS_TIMER              = Metrics.timer(name(KeysDynamoDb.class, "storeKeys")); | ||||||
|  |     private static final Timer               TAKE_KEY_FOR_DEVICE_TIMER     = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForDevice")); | ||||||
|  |     private static final Timer               TAKE_KEYS_FOR_ACCOUNT_TIMER   = Metrics.timer(name(KeysDynamoDb.class, "takeKeyForAccount")); | ||||||
|  |     private static final Timer               GET_KEY_COUNT_TIMER           = Metrics.timer(name(KeysDynamoDb.class, "getKeyCount")); | ||||||
|  |     private static final Timer               DELETE_KEYS_FOR_DEVICE_TIMER  = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForDevice")); | ||||||
|  |     private static final Timer               DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(KeysDynamoDb.class, "deleteKeysForAccount")); | ||||||
|  |     private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION    = Metrics.summary(name(KeysDynamoDb.class, "contestedKeys")); | ||||||
|  | 
 | ||||||
|  |     public KeysDynamoDb(final DynamoDB dynamoDB, final String tableName) { | ||||||
|  |         super(dynamoDB); | ||||||
|  | 
 | ||||||
|  |         this.table = dynamoDB.getTable(tableName); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public void store(final Account account, final long deviceId, final List<PreKey> keys) { | ||||||
|  |         STORE_KEYS_TIMER.record(() -> { | ||||||
|  |             delete(account, deviceId); | ||||||
|  | 
 | ||||||
|  |             writeInBatches(keys, batch -> { | ||||||
|  |                 final TableWriteItems items = new TableWriteItems(table.getTableName()); | ||||||
|  | 
 | ||||||
|  |                 for (final PreKey preKey : batch) { | ||||||
|  |                     items.addItemToPut(getItemFromPreKey(account.getUuid(), deviceId, preKey)); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 executeTableWriteItemsUntilComplete(items); | ||||||
|  |             }); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public List<KeyRecord> take(final Account account, final long deviceId) { | ||||||
|  |         return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { | ||||||
|  |             final byte[] partitionKey = getPartitionKey(account.getUuid()); | ||||||
|  | 
 | ||||||
|  |             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||||
|  |                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||||
|  |                                                        .withValueMap(Map.of(":uuid", partitionKey, | ||||||
|  |                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||||
|  |                                                        .withProjectionExpression(KEY_DEVICE_ID_KEY_ID) | ||||||
|  |                                                        .withConsistentRead(false); | ||||||
|  | 
 | ||||||
|  |             int contestedKeys = 0; | ||||||
|  | 
 | ||||||
|  |             try { | ||||||
|  |                 for (final Item candidate : table.query(querySpec)) { | ||||||
|  |                     final DeleteItemSpec deleteItemSpec = new DeleteItemSpec().withPrimaryKey(KEY_ACCOUNT_UUID, partitionKey, KEY_DEVICE_ID_KEY_ID, candidate.getBinary(KEY_DEVICE_ID_KEY_ID)) | ||||||
|  |                                                                               .withReturnValues(ReturnValue.ALL_OLD); | ||||||
|  | 
 | ||||||
|  |                     final DeleteItemOutcome outcome = table.deleteItem(deleteItemSpec); | ||||||
|  | 
 | ||||||
|  |                     if (outcome.getItem() != null) { | ||||||
|  |                         final PreKey preKey = getPreKeyFromItem(outcome.getItem()); | ||||||
|  |                         return List.of(new KeyRecord(-1, account.getNumber(), deviceId, preKey.getKeyId(), preKey.getPublicKey())); | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     contestedKeys++; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 return Collections.emptyList(); | ||||||
|  |             } finally { | ||||||
|  |                 CONTESTED_KEY_DISTRIBUTION.record(contestedKeys); | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public List<KeyRecord> take(final Account account) { | ||||||
|  |         return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { | ||||||
|  |             final List<KeyRecord> keyRecords = new ArrayList<>(); | ||||||
|  | 
 | ||||||
|  |             for (final Device device : account.getDevices()) { | ||||||
|  |                 keyRecords.addAll(take(account, device.getId())); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             return keyRecords; | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public int getCount(final Account account, final long deviceId) { | ||||||
|  |         return GET_KEY_COUNT_TIMER.record(() -> { | ||||||
|  |             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||||
|  |                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||||
|  |                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), | ||||||
|  |                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||||
|  |                                                        .withSelect(Select.COUNT) | ||||||
|  |                                                        .withConsistentRead(false); | ||||||
|  | 
 | ||||||
|  |             // This is very confusing, but does appear to be the intended behavior. See: | ||||||
|  |             // | ||||||
|  |             // - https://github.com/aws/aws-sdk-java/issues/693 | ||||||
|  |             // - https://github.com/aws/aws-sdk-java/issues/915 | ||||||
|  |             return table.query(querySpec).firstPage().getLowLevelResult().getQueryResult().getCount(); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public void delete(final Account account) { | ||||||
|  |         DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { | ||||||
|  |             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid") | ||||||
|  |                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID)) | ||||||
|  |                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()))) | ||||||
|  |                                                        .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) | ||||||
|  |                                                        .withConsistentRead(true); | ||||||
|  | 
 | ||||||
|  |             deleteItemsMatchingQuery(querySpec); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @VisibleForTesting | ||||||
|  |     void delete(final Account account, final long deviceId) { | ||||||
|  |         DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { | ||||||
|  |             final QuerySpec querySpec = new QuerySpec().withKeyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") | ||||||
|  |                                                        .withNameMap(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) | ||||||
|  |                                                        .withValueMap(Map.of(":uuid", getPartitionKey(account.getUuid()), | ||||||
|  |                                                                             ":sortprefix", getSortKeyPrefix(deviceId))) | ||||||
|  |                                                        .withProjectionExpression(KEY_ACCOUNT_UUID + ", " + KEY_DEVICE_ID_KEY_ID) | ||||||
|  |                                                        .withConsistentRead(true); | ||||||
|  | 
 | ||||||
|  |             deleteItemsMatchingQuery(querySpec); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private void deleteItemsMatchingQuery(final QuerySpec querySpec) { | ||||||
|  |         writeInBatches(table.query(querySpec), batch -> { | ||||||
|  |             final TableWriteItems writeItems = new TableWriteItems(table.getTableName()); | ||||||
|  | 
 | ||||||
|  |             for (final Item item : batch) { | ||||||
|  |                 writeItems.addPrimaryKeyToDelete(new PrimaryKey(KEY_ACCOUNT_UUID, item.getBinary(KEY_ACCOUNT_UUID), KEY_DEVICE_ID_KEY_ID, item.getBinary(KEY_DEVICE_ID_KEY_ID))); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             executeTableWriteItemsUntilComplete(writeItems); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private static byte[] getPartitionKey(final UUID accountUuid) { | ||||||
|  |         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); | ||||||
|  |         byteBuffer.putLong(accountUuid.getMostSignificantBits()); | ||||||
|  |         byteBuffer.putLong(accountUuid.getLeastSignificantBits()); | ||||||
|  |         return byteBuffer.array(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private static byte[] getSortKey(final long deviceId, final long keyId) { | ||||||
|  |         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); | ||||||
|  |         byteBuffer.putLong(deviceId); | ||||||
|  |         byteBuffer.putLong(keyId); | ||||||
|  |         return byteBuffer.array(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private static byte[] getSortKeyPrefix(final long deviceId) { | ||||||
|  |         final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); | ||||||
|  |         byteBuffer.putLong(deviceId); | ||||||
|  |         return byteBuffer.array(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private Item getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { | ||||||
|  |         return new Item().withBinary(KEY_ACCOUNT_UUID, getPartitionKey(accountUuid)) | ||||||
|  |                          .withBinary(KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId())) | ||||||
|  |                          .withString(KEY_PUBLIC_KEY, preKey.getPublicKey()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private PreKey getPreKeyFromItem(final Item item) { | ||||||
|  |         final long keyId = ByteBuffer.wrap(item.getBinary(KEY_DEVICE_ID_KEY_ID)).getLong(8); | ||||||
|  |         return new PreKey(keyId, item.getString(KEY_PUBLIC_KEY)); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -5,7 +5,6 @@ | ||||||
| 
 | 
 | ||||||
| package org.whispersystems.textsecuregcm.storage; | package org.whispersystems.textsecuregcm.storage; | ||||||
| 
 | 
 | ||||||
| import com.amazonaws.services.dynamodbv2.document.BatchWriteItemOutcome; |  | ||||||
| import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; | import com.amazonaws.services.dynamodbv2.document.DeleteItemOutcome; | ||||||
| import com.amazonaws.services.dynamodbv2.document.DynamoDB; | import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||||
| import com.amazonaws.services.dynamodbv2.document.Index; | import com.amazonaws.services.dynamodbv2.document.Index; | ||||||
|  | @ -17,11 +16,8 @@ import com.amazonaws.services.dynamodbv2.document.api.QueryApi; | ||||||
| import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; | import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec; | ||||||
| import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; | import com.amazonaws.services.dynamodbv2.document.spec.QuerySpec; | ||||||
| import com.amazonaws.services.dynamodbv2.model.ReturnValue; | import com.amazonaws.services.dynamodbv2.model.ReturnValue; | ||||||
| import io.micrometer.core.instrument.Counter; |  | ||||||
| import io.micrometer.core.instrument.Timer; | import io.micrometer.core.instrument.Timer; | ||||||
| import org.apache.commons.lang3.StringUtils; | import org.apache.commons.lang3.StringUtils; | ||||||
| import org.slf4j.Logger; |  | ||||||
| import org.slf4j.LoggerFactory; |  | ||||||
| import org.whispersystems.textsecuregcm.entities.MessageProtos; | import org.whispersystems.textsecuregcm.entities.MessageProtos; | ||||||
| import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; | import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; | ||||||
| 
 | 
 | ||||||
|  | @ -33,17 +29,11 @@ import java.util.List; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
| import java.util.Optional; | import java.util.Optional; | ||||||
| import java.util.UUID; | import java.util.UUID; | ||||||
| import java.util.concurrent.atomic.AtomicReference; |  | ||||||
| import java.util.function.Consumer; |  | ||||||
| 
 | 
 | ||||||
| import static com.codahale.metrics.MetricRegistry.name; | import static com.codahale.metrics.MetricRegistry.name; | ||||||
| import static io.micrometer.core.instrument.Metrics.counter; |  | ||||||
| import static io.micrometer.core.instrument.Metrics.timer; | import static io.micrometer.core.instrument.Metrics.timer; | ||||||
| 
 | 
 | ||||||
| public class MessagesDynamoDb { | public class MessagesDynamoDb extends AbstractDynamoDbStore { | ||||||
|   private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25;  // This was arbitrarily chosen and may be entirely too high. |  | ||||||
|   private static final int DYNAMO_DB_MAX_BATCH_SIZE = 25;  // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this. |  | ||||||
|   public static final int RESULT_SET_CHUNK_SIZE = 100; |  | ||||||
| 
 | 
 | ||||||
|   private static final String KEY_PARTITION = "H"; |   private static final String KEY_PARTITION = "H"; | ||||||
|   private static final String KEY_SORT = "S"; |   private static final String KEY_SORT = "S"; | ||||||
|  | @ -60,10 +50,6 @@ public class MessagesDynamoDb { | ||||||
|   private static final String KEY_CONTENT = "C"; |   private static final String KEY_CONTENT = "C"; | ||||||
|   private static final String KEY_TTL = "E"; |   private static final String KEY_TTL = "E"; | ||||||
| 
 | 
 | ||||||
|   private final Logger logger = LoggerFactory.getLogger(getClass()); |  | ||||||
|   private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); |  | ||||||
|   private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false"); |  | ||||||
|   private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); |  | ||||||
|   private final Timer storeTimer = timer(name(getClass(), "store")); |   private final Timer storeTimer = timer(name(getClass(), "store")); | ||||||
|   private final Timer loadTimer = timer(name(getClass(), "load")); |   private final Timer loadTimer = timer(name(getClass(), "load")); | ||||||
|   private final Timer deleteBySourceAndTimestamp = timer(name(getClass(), "delete", "sourceAndTimestamp")); |   private final Timer deleteBySourceAndTimestamp = timer(name(getClass(), "delete", "sourceAndTimestamp")); | ||||||
|  | @ -71,18 +57,18 @@ public class MessagesDynamoDb { | ||||||
|   private final Timer deleteByAccount = timer(name(getClass(), "delete", "account")); |   private final Timer deleteByAccount = timer(name(getClass(), "delete", "account")); | ||||||
|   private final Timer deleteByDevice = timer(name(getClass(), "delete", "device")); |   private final Timer deleteByDevice = timer(name(getClass(), "delete", "device")); | ||||||
| 
 | 
 | ||||||
|   private final DynamoDB dynamoDb; |  | ||||||
|   private final String tableName; |   private final String tableName; | ||||||
|   private final Duration timeToLive; |   private final Duration timeToLive; | ||||||
| 
 | 
 | ||||||
|   public MessagesDynamoDb(DynamoDB dynamoDb, String tableName, Duration timeToLive) { |   public MessagesDynamoDb(DynamoDB dynamoDb, String tableName, Duration timeToLive) { | ||||||
|     this.dynamoDb = dynamoDb; |     super(dynamoDb); | ||||||
|  | 
 | ||||||
|     this.tableName = tableName; |     this.tableName = tableName; | ||||||
|     this.timeToLive = timeToLive; |     this.timeToLive = timeToLive; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { |   public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { | ||||||
|     storeTimer.record(() -> doInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId), DYNAMO_DB_MAX_BATCH_SIZE)); |     storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId))); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void storeBatch(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { |   private void storeBatch(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) { | ||||||
|  | @ -135,7 +121,7 @@ public class MessagesDynamoDb { | ||||||
|                                                  .withValueMap(Map.of(":part", partitionKey, |                                                  .withValueMap(Map.of(":part", partitionKey, | ||||||
|                                                                       ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) |                                                                       ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) | ||||||
|                                                  .withMaxResultSize(numberOfMessagesToFetch); |                                                  .withMaxResultSize(numberOfMessagesToFetch); | ||||||
|       final Table table = dynamoDb.getTable(tableName); |       final Table table = getDynamoDb().getTable(tableName); | ||||||
|       List<OutgoingMessageEntity> messageEntities = new ArrayList<>(numberOfMessagesToFetch); |       List<OutgoingMessageEntity> messageEntities = new ArrayList<>(numberOfMessagesToFetch); | ||||||
|       for (Item message : table.query(querySpec)) { |       for (Item message : table.query(querySpec)) { | ||||||
|         messageEntities.add(convertItemToOutgoingMessageEntity(message)); |         messageEntities.add(convertItemToOutgoingMessageEntity(message)); | ||||||
|  | @ -164,7 +150,7 @@ public class MessagesDynamoDb { | ||||||
|                                                                       ":source", source, |                                                                       ":source", source, | ||||||
|                                                                       ":timestamp", timestamp)); |                                                                       ":timestamp", timestamp)); | ||||||
| 
 | 
 | ||||||
|       final Table table = dynamoDb.getTable(tableName); |       final Table table = getDynamoDb().getTable(tableName); | ||||||
|       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, table); |       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, table); | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
|  | @ -179,7 +165,7 @@ public class MessagesDynamoDb { | ||||||
|                                                                      "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) |                                                                      "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) | ||||||
|                                                  .withValueMap(Map.of(":part", partitionKey, |                                                  .withValueMap(Map.of(":part", partitionKey, | ||||||
|                                                                       ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))); |                                                                       ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))); | ||||||
|       final Table table = dynamoDb.getTable(tableName); |       final Table table = getDynamoDb().getTable(tableName); | ||||||
|       final Index index = table.getIndex(LOCAL_INDEX_MESSAGE_UUID_NAME); |       final Index index = table.getIndex(LOCAL_INDEX_MESSAGE_UUID_NAME); | ||||||
|       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, index); |       return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(table, partitionKey, querySpec, index); | ||||||
|     }); |     }); | ||||||
|  | @ -241,62 +227,24 @@ public class MessagesDynamoDb { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void deleteRowsMatchingQuery(byte[] partitionKey, QuerySpec querySpec) { |   private void deleteRowsMatchingQuery(byte[] partitionKey, QuerySpec querySpec) { | ||||||
|     final Table table = dynamoDb.getTable(tableName); |     final Table table = getDynamoDb().getTable(tableName); | ||||||
|     doInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch), DYNAMO_DB_MAX_BATCH_SIZE); |     writeInBatches(table.query(querySpec), (itemBatch) -> deleteItems(partitionKey, itemBatch)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void deleteItems(byte[] partitionKey, List<Item> items) { |   private void deleteItems(byte[] partitionKey, List<Item> items) { | ||||||
|     final TableWriteItems tableWriteItems = new TableWriteItems(tableName); |     final TableWriteItems tableWriteItems = new TableWriteItems(tableName); | ||||||
|     items.stream().map((x) -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, x.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); |     items.stream().map(item -> new PrimaryKey(KEY_PARTITION, partitionKey, KEY_SORT, item.getBinary(KEY_SORT))).forEach(tableWriteItems::addPrimaryKeyToDelete); | ||||||
|     executeTableWriteItemsUntilComplete(tableWriteItems); |     executeTableWriteItemsUntilComplete(tableWriteItems); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void executeTableWriteItemsUntilComplete(TableWriteItems items) { |  | ||||||
|     AtomicReference<BatchWriteItemOutcome> outcome = new AtomicReference<>(); |  | ||||||
|     batchWriteItemsFirstPass.record(() -> { |  | ||||||
|       outcome.set(dynamoDb.batchWriteItem(items)); |  | ||||||
|     }); |  | ||||||
|     int attemptCount = 0; |  | ||||||
|     while (!outcome.get().getUnprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { |  | ||||||
|       batchWriteItemsRetryPass.record(() -> { |  | ||||||
|         outcome.set(dynamoDb.batchWriteItemUnprocessed(outcome.get().getUnprocessedItems())); |  | ||||||
|       }); |  | ||||||
|       ++attemptCount; |  | ||||||
|     } |  | ||||||
|     if (!outcome.get().getUnprocessedItems().isEmpty()) { |  | ||||||
|       logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, outcome.get().getUnprocessedItems().size()); |  | ||||||
|       batchWriteItemsUnprocessed.increment(outcome.get().getUnprocessedItems().size()); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   private long getTtlForMessage(MessageProtos.Envelope message) { |   private long getTtlForMessage(MessageProtos.Envelope message) { | ||||||
|     return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); |     return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private static <T> void doInBatches(final Iterable<T> items, final Consumer<List<T>> action, final int batchSize) { |  | ||||||
|     List<T> batch = new ArrayList<>(batchSize); |  | ||||||
| 
 |  | ||||||
|     for (T item : items) { |  | ||||||
|       batch.add(item); |  | ||||||
| 
 |  | ||||||
|       if (batch.size() == batchSize) { |  | ||||||
|         action.accept(batch); |  | ||||||
|         batch.clear(); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     if (!batch.isEmpty()) { |  | ||||||
|       action.accept(batch); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   private static byte[] convertPartitionKey(final UUID destinationAccountUuid) { |   private static byte[] convertPartitionKey(final UUID destinationAccountUuid) { | ||||||
|     return convertUuidToBytes(destinationAccountUuid); |     return convertUuidToBytes(destinationAccountUuid); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private static UUID convertPartitionKey(final byte[] bytes) { |  | ||||||
|     return convertUuidFromBytes(bytes, "partition key"); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   private static byte[] convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { |   private static byte[] convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { | ||||||
|     ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); |     ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); | ||||||
|     byteBuffer.putLong(destinationDeviceId); |     byteBuffer.putLong(destinationDeviceId); | ||||||
|  |  | ||||||
|  | @ -0,0 +1,23 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.storage; | ||||||
|  | 
 | ||||||
|  | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
|  | 
 | ||||||
|  | import java.util.List; | ||||||
|  | 
 | ||||||
|  | public interface PreKeyStore { | ||||||
|  | 
 | ||||||
|  |     void store(Account account, long deviceId, List<PreKey> keys); | ||||||
|  | 
 | ||||||
|  |     int getCount(Account account, long deviceId); | ||||||
|  | 
 | ||||||
|  |     List<KeyRecord> take(Account account, long deviceId); | ||||||
|  | 
 | ||||||
|  |     List<KeyRecord> take(Account account); | ||||||
|  | 
 | ||||||
|  |     void delete(Account account); | ||||||
|  | } | ||||||
|  | @ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | ||||||
| import org.whispersystems.textsecuregcm.storage.Keys; | import org.whispersystems.textsecuregcm.storage.Keys; | ||||||
|  | import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||||
| import org.whispersystems.textsecuregcm.storage.Messages; | import org.whispersystems.textsecuregcm.storage.Messages; | ||||||
| import org.whispersystems.textsecuregcm.storage.MessagesCache; | import org.whispersystems.textsecuregcm.storage.MessagesCache; | ||||||
| import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; | import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; | ||||||
|  | @ -97,7 +98,16 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | ||||||
|               .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) |               .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getMessageDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||||
|                                                                 .withRequestTimeout((int) configuration.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) |                                                                 .withRequestTimeout((int) configuration.getMessageDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||||
|               .withCredentials(InstanceProfileCredentialsProvider.getInstance()); |               .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||||
|  | 
 | ||||||
|  |       AmazonDynamoDBClientBuilder keysDynamoDbClientBuilder = AmazonDynamoDBClientBuilder | ||||||
|  |               .standard() | ||||||
|  |               .withRegion(configuration.getKeysDynamoDbConfiguration().getRegion()) | ||||||
|  |               .withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getKeysDynamoDbConfiguration().getClientExecutionTimeout().toMillis())) | ||||||
|  |                                                                 .withRequestTimeout((int) configuration.getKeysDynamoDbConfiguration().getClientRequestTimeout().toMillis())) | ||||||
|  |               .withCredentials(InstanceProfileCredentialsProvider.getInstance()); | ||||||
|  | 
 | ||||||
|       DynamoDB messageDynamoDb = new DynamoDB(clientBuilder.build()); |       DynamoDB messageDynamoDb = new DynamoDB(clientBuilder.build()); | ||||||
|  |       DynamoDB preKeyDynamoDb  = new DynamoDB(keysDynamoDbClientBuilder.build()); | ||||||
| 
 | 
 | ||||||
|       FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", configuration.getCacheClusterConfiguration(), redisClusterClientResources); |       FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", configuration.getCacheClusterConfiguration(), redisClusterClientResources); | ||||||
| 
 | 
 | ||||||
|  | @ -111,6 +121,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | ||||||
|       Profiles                  profiles             = new Profiles(accountDatabase); |       Profiles                  profiles             = new Profiles(accountDatabase); | ||||||
|       ReservedUsernames         reservedUsernames    = new ReservedUsernames(accountDatabase); |       ReservedUsernames         reservedUsernames    = new ReservedUsernames(accountDatabase); | ||||||
|       Keys                      keys                 = new Keys(accountDatabase, configuration.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); |       Keys                      keys                 = new Keys(accountDatabase, configuration.getAccountsDatabaseConfiguration().getKeyOperationRetryConfiguration()); | ||||||
|  |       KeysDynamoDb              keysDynamoDb         = new KeysDynamoDb(messageDynamoDb, configuration.getKeysDynamoDbConfiguration().getTableName()); | ||||||
|       Messages                  messages             = new Messages(messageDatabase); |       Messages                  messages             = new Messages(messageDatabase); | ||||||
|       MessagesDynamoDb          messagesDynamoDb     = new MessagesDynamoDb(messageDynamoDb, configuration.getMessageDynamoDbConfiguration().getTableName(), configuration.getMessageDynamoDbConfiguration().getTimeToLive()); |       MessagesDynamoDb          messagesDynamoDb     = new MessagesDynamoDb(messageDynamoDb, configuration.getMessageDynamoDbConfiguration().getTableName(), configuration.getMessageDynamoDbConfiguration().getTimeToLive()); | ||||||
|       ReplicatedJedisPool       redisClient          = new RedisClientFactory("directory_cache_delete_command", configuration.getDirectoryConfiguration().getRedisConfiguration().getUrl(), configuration.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls(), configuration.getDirectoryConfiguration().getRedisConfiguration().getCircuitBreakerConfiguration()).getRedisClientPool(); |       ReplicatedJedisPool       redisClient          = new RedisClientFactory("directory_cache_delete_command", configuration.getDirectoryConfiguration().getRedisConfiguration().getUrl(), configuration.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls(), configuration.getDirectoryConfiguration().getRedisConfiguration().getCircuitBreakerConfiguration()).getRedisClientPool(); | ||||||
|  | @ -124,7 +135,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura | ||||||
|       UsernamesManager          usernamesManager     = new UsernamesManager(usernames, reservedUsernames, cacheCluster); |       UsernamesManager          usernamesManager     = new UsernamesManager(usernames, reservedUsernames, cacheCluster); | ||||||
|       ProfilesManager           profilesManager      = new ProfilesManager(profiles, cacheCluster); |       ProfilesManager           profilesManager      = new ProfilesManager(profiles, cacheCluster); | ||||||
|       MessagesManager           messagesManager      = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, new ExperimentEnrollmentManager(dynamicConfigurationManager)); |       MessagesManager           messagesManager      = new MessagesManager(messages, messagesDynamoDb, messagesCache, pushLatencyManager, new ExperimentEnrollmentManager(dynamicConfigurationManager)); | ||||||
|       AccountsManager           accountsManager      = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |       AccountsManager           accountsManager      = new AccountsManager(accounts, directory, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
| 
 | 
 | ||||||
|       for (String user: users) { |       for (String user: users) { | ||||||
|         Optional<Account> account = accountsManager.get(user); |         Optional<Account> account = accountsManager.get(user); | ||||||
|  |  | ||||||
|  | @ -0,0 +1,40 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.storage; | ||||||
|  | 
 | ||||||
|  | import com.amazonaws.services.dynamodbv2.document.DynamoDB; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.AttributeDefinition; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.CreateTableRequest; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.KeySchemaElement; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput; | ||||||
|  | import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType; | ||||||
|  | import org.whispersystems.textsecuregcm.tests.util.LocalDynamoDbRule; | ||||||
|  | 
 | ||||||
|  | public class KeysDynamoDbRule extends LocalDynamoDbRule { | ||||||
|  |     public static final String TABLE_NAME = "Signal_Keys_Test"; | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     protected void before() throws Throwable { | ||||||
|  |         super.before(); | ||||||
|  | 
 | ||||||
|  |         final DynamoDB dynamoDB = getDynamoDB(); | ||||||
|  | 
 | ||||||
|  |         final CreateTableRequest createTableRequest = new CreateTableRequest() | ||||||
|  |                 .withTableName(TABLE_NAME) | ||||||
|  |                 .withKeySchema(new KeySchemaElement(KeysDynamoDb.KEY_ACCOUNT_UUID, "HASH"), | ||||||
|  |                                new KeySchemaElement(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, "RANGE")) | ||||||
|  |                 .withAttributeDefinitions(new AttributeDefinition(KeysDynamoDb.KEY_ACCOUNT_UUID, ScalarAttributeType.B), | ||||||
|  |                                           new AttributeDefinition(KeysDynamoDb.KEY_DEVICE_ID_KEY_ID, ScalarAttributeType.B)) | ||||||
|  |                 .withProvisionedThroughput(new ProvisionedThroughput(20L, 20L)); | ||||||
|  | 
 | ||||||
|  |         dynamoDB.createTable(createTableRequest); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     protected void after() { | ||||||
|  |         super.after(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,136 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2021 Signal Messenger, LLC | ||||||
|  |  * SPDX-License-Identifier: AGPL-3.0-only | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.whispersystems.textsecuregcm.storage; | ||||||
|  | 
 | ||||||
|  | import org.junit.Before; | ||||||
|  | import org.junit.ClassRule; | ||||||
|  | import org.junit.Test; | ||||||
|  | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
|  | 
 | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.HashSet; | ||||||
|  | import java.util.List; | ||||||
|  | import java.util.Set; | ||||||
|  | import java.util.UUID; | ||||||
|  | 
 | ||||||
|  | import static org.junit.Assert.assertEquals; | ||||||
|  | import static org.mockito.Mockito.mock; | ||||||
|  | import static org.mockito.Mockito.when; | ||||||
|  | 
 | ||||||
|  | public class KeysDynamoDbTest { | ||||||
|  | 
 | ||||||
|  |     private Account account; | ||||||
|  |     private KeysDynamoDb keysDynamoDb; | ||||||
|  | 
 | ||||||
|  |     @ClassRule | ||||||
|  |     public static KeysDynamoDbRule dynamoDbRule = new KeysDynamoDbRule(); | ||||||
|  | 
 | ||||||
|  |     private static final String ACCOUNT_NUMBER = "+18005551234"; | ||||||
|  |     private static final long DEVICE_ID = 1L; | ||||||
|  | 
 | ||||||
|  |     @Before | ||||||
|  |     public void setup() { | ||||||
|  |         keysDynamoDb = new KeysDynamoDb(dynamoDbRule.getDynamoDB(), KeysDynamoDbRule.TABLE_NAME); | ||||||
|  | 
 | ||||||
|  |         account = mock(Account.class); | ||||||
|  |         when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); | ||||||
|  |         when(account.getUuid()).thenReturn(UUID.randomUUID()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testStore() { | ||||||
|  |         assertEquals("Initial pre-key count for an account should be zero", | ||||||
|  |                 0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||||
|  |         assertEquals("Repeatedly storing same key should have no effect", | ||||||
|  |                 1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); | ||||||
|  |         assertEquals("Inserting a new key should overwrite all prior keys for the given account/device", | ||||||
|  |                 1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); | ||||||
|  |         assertEquals("Inserting multiple new keys should overwrite all prior keys for the given account/device", | ||||||
|  |                 2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testTakeAccount() { | ||||||
|  |         final Device firstDevice = mock(Device.class); | ||||||
|  |         final Device secondDevice = mock(Device.class); | ||||||
|  | 
 | ||||||
|  |         when(firstDevice.getId()).thenReturn(DEVICE_ID); | ||||||
|  |         when(secondDevice.getId()).thenReturn(DEVICE_ID + 1); | ||||||
|  |         when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice)); | ||||||
|  | 
 | ||||||
|  |         assertEquals(Collections.emptyList(), keysDynamoDb.take(account)); | ||||||
|  | 
 | ||||||
|  |         final PreKey firstDevicePreKey = new PreKey(1, "public-key"); | ||||||
|  |         final PreKey secondDevicePreKey = new PreKey(2, "second-key"); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(firstDevicePreKey)); | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey)); | ||||||
|  | 
 | ||||||
|  |         final Set<KeyRecord> expectedKeys = Set.of( | ||||||
|  |                 new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, firstDevicePreKey.getKeyId(), firstDevicePreKey.getPublicKey()), | ||||||
|  |                 new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID + 1, secondDevicePreKey.getKeyId(), secondDevicePreKey.getPublicKey())); | ||||||
|  | 
 | ||||||
|  |         assertEquals(expectedKeys, new HashSet<>(keysDynamoDb.take(account))); | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testTakeAccountAndDeviceId() { | ||||||
|  |         assertEquals(Collections.emptyList(), keysDynamoDb.take(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         final PreKey preKey = new PreKey(1, "public-key"); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); | ||||||
|  |         assertEquals(List.of(new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, preKey.getKeyId(), preKey.getPublicKey())), keysDynamoDb.take(account, DEVICE_ID)); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testGetCount() { | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testDeleteByAccount() { | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); | ||||||
|  | 
 | ||||||
|  |         assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.delete(account); | ||||||
|  | 
 | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testDeleteByAccountAndDevice() { | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); | ||||||
|  |         keysDynamoDb.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); | ||||||
|  | 
 | ||||||
|  |         assertEquals(2, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||||
|  | 
 | ||||||
|  |         keysDynamoDb.delete(account, DEVICE_ID); | ||||||
|  | 
 | ||||||
|  |         assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); | ||||||
|  |         assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID + 1)); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -46,7 +46,7 @@ import io.dropwizard.testing.junit.ResourceTestRule; | ||||||
| import static org.assertj.core.api.Assertions.assertThat; | import static org.assertj.core.api.Assertions.assertThat; | ||||||
| import static org.mockito.Mockito.*; | import static org.mockito.Mockito.*; | ||||||
| 
 | 
 | ||||||
| public class KeyControllerTest { | public class KeysControllerTest { | ||||||
| 
 | 
 | ||||||
|   private static final String EXISTS_NUMBER = "+14152222222"; |   private static final String EXISTS_NUMBER = "+14152222222"; | ||||||
|   private static final UUID   EXISTS_UUID   = UUID.randomUUID(); |   private static final UUID   EXISTS_UUID   = UUID.randomUUID(); | ||||||
|  | @ -141,18 +141,16 @@ public class KeyControllerTest { | ||||||
| 
 | 
 | ||||||
|     List<KeyRecord> singleDevice = new LinkedList<>(); |     List<KeyRecord> singleDevice = new LinkedList<>(); | ||||||
|     singleDevice.add(SAMPLE_KEY); |     singleDevice.add(SAMPLE_KEY); | ||||||
|     when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice); |     when(keys.take(eq(existsAccount), eq(1L))).thenReturn(singleDevice); | ||||||
| 
 |  | ||||||
|     when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>()); |  | ||||||
| 
 | 
 | ||||||
|     List<KeyRecord> multiDevice = new LinkedList<>(); |     List<KeyRecord> multiDevice = new LinkedList<>(); | ||||||
|     multiDevice.add(SAMPLE_KEY); |     multiDevice.add(SAMPLE_KEY); | ||||||
|     multiDevice.add(SAMPLE_KEY2); |     multiDevice.add(SAMPLE_KEY2); | ||||||
|     multiDevice.add(SAMPLE_KEY3); |     multiDevice.add(SAMPLE_KEY3); | ||||||
|     multiDevice.add(SAMPLE_KEY4); |     multiDevice.add(SAMPLE_KEY4); | ||||||
|     when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice); |     when(keys.take(existsAccount)).thenReturn(multiDevice); | ||||||
| 
 | 
 | ||||||
|     when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); |     when(keys.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5); | ||||||
| 
 | 
 | ||||||
|     when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); |     when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); | ||||||
|     when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); |     when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); | ||||||
|  | @ -169,7 +167,7 @@ public class KeyControllerTest { | ||||||
| 
 | 
 | ||||||
|     assertThat(result.getCount()).isEqualTo(4); |     assertThat(result.getCount()).isEqualTo(4); | ||||||
| 
 | 
 | ||||||
|     verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); |     verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  | @ -183,7 +181,7 @@ public class KeyControllerTest { | ||||||
| 
 | 
 | ||||||
|     assertThat(result.getCount()).isEqualTo(4); |     assertThat(result.getCount()).isEqualTo(4); | ||||||
| 
 | 
 | ||||||
|     verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); |     verify(keys).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -283,7 +281,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); |     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); |     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); |     verify(keys).take(eq(existsAccount), eq(1L)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -301,7 +299,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); |     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); |     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); |     verify(keys).take(eq(existsAccount), eq(1L)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -320,7 +318,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); |     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); |     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); |     verify(keys).take(eq(existsAccount), eq(1L)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -338,7 +336,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); |     assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); | ||||||
|     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); |     assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); |     verify(keys).take(eq(existsAccount), eq(1L)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -414,7 +412,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(signedPreKey).isNull(); |     assertThat(signedPreKey).isNull(); | ||||||
|     assertThat(deviceId).isEqualTo(4); |     assertThat(deviceId).isEqualTo(4); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER)); |     verify(keys).take(eq(existsAccount)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -464,7 +462,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(signedPreKey).isNull(); |     assertThat(signedPreKey).isNull(); | ||||||
|     assertThat(deviceId).isEqualTo(4); |     assertThat(deviceId).isEqualTo(4); | ||||||
| 
 | 
 | ||||||
|     verify(keys).get(eq(EXISTS_NUMBER)); |     verify(keys).take(eq(existsAccount)); | ||||||
|     verifyNoMoreInteractions(keys); |     verifyNoMoreInteractions(keys); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -533,7 +531,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(response.getStatus()).isEqualTo(204); |     assertThat(response.getStatus()).isEqualTo(204); | ||||||
| 
 | 
 | ||||||
|     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); |     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); | ||||||
|     verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture()); |     verify(keys).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); | ||||||
| 
 | 
 | ||||||
|     List<PreKey> capturedList = listCaptor.getValue(); |     List<PreKey> capturedList = listCaptor.getValue(); | ||||||
|     assertThat(capturedList.size()).isEqualTo(1); |     assertThat(capturedList.size()).isEqualTo(1); | ||||||
|  | @ -567,7 +565,7 @@ public class KeyControllerTest { | ||||||
|     assertThat(response.getStatus()).isEqualTo(204); |     assertThat(response.getStatus()).isEqualTo(204); | ||||||
| 
 | 
 | ||||||
|     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); |     ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class); | ||||||
|     verify(keys).store(eq(AuthHelper.DISABLED_NUMBER), eq(1L), listCaptor.capture()); |     verify(keys).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); | ||||||
| 
 | 
 | ||||||
|     List<PreKey> capturedList = listCaptor.getValue(); |     List<PreKey> capturedList = listCaptor.getValue(); | ||||||
|     assertThat(capturedList.size()).isEqualTo(1); |     assertThat(capturedList.size()).isEqualTo(1); | ||||||
|  | @ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.storage; | ||||||
| import io.lettuce.core.RedisException; | import io.lettuce.core.RedisException; | ||||||
| import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; | import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; | ||||||
| import org.junit.Test; | import org.junit.Test; | ||||||
| import org.whispersystems.textsecuregcm.entities.Profile; |  | ||||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||||
| import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; | import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; | ||||||
| import org.whispersystems.textsecuregcm.storage.Account; | import org.whispersystems.textsecuregcm.storage.Account; | ||||||
|  | @ -16,6 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts; | ||||||
| import org.whispersystems.textsecuregcm.storage.AccountsManager; | import org.whispersystems.textsecuregcm.storage.AccountsManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.DirectoryManager; | import org.whispersystems.textsecuregcm.storage.DirectoryManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.Keys; | import org.whispersystems.textsecuregcm.storage.Keys; | ||||||
|  | import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; | ||||||
| import org.whispersystems.textsecuregcm.storage.MessagesManager; | import org.whispersystems.textsecuregcm.storage.MessagesManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.ProfilesManager; | import org.whispersystems.textsecuregcm.storage.ProfilesManager; | ||||||
| import org.whispersystems.textsecuregcm.storage.UsernamesManager; | import org.whispersystems.textsecuregcm.storage.UsernamesManager; | ||||||
|  | @ -46,6 +46,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -55,7 +56,7 @@ public class AccountsManagerTest { | ||||||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString()); |     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString()); | ||||||
|     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); |     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> account         = accountsManager.get("+14152222222"); |     Optional<Account> account         = accountsManager.get("+14152222222"); | ||||||
| 
 | 
 | ||||||
|     assertTrue(account.isPresent()); |     assertTrue(account.isPresent()); | ||||||
|  | @ -76,6 +77,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -84,7 +86,7 @@ public class AccountsManagerTest { | ||||||
| 
 | 
 | ||||||
|     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); |     when(commands.get(eq("Account3::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> account         = accountsManager.get(uuid); |     Optional<Account> account         = accountsManager.get(uuid); | ||||||
| 
 | 
 | ||||||
|     assertTrue(account.isPresent()); |     assertTrue(account.isPresent()); | ||||||
|  | @ -106,6 +108,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -115,7 +118,7 @@ public class AccountsManagerTest { | ||||||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null); |     when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null); | ||||||
|     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); |     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> retrieved       = accountsManager.get("+14152222222"); |     Optional<Account> retrieved       = accountsManager.get("+14152222222"); | ||||||
| 
 | 
 | ||||||
|     assertTrue(retrieved.isPresent()); |     assertTrue(retrieved.isPresent()); | ||||||
|  | @ -138,6 +141,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -147,7 +151,7 @@ public class AccountsManagerTest { | ||||||
|     when(commands.get(eq("Account3::" + uuid))).thenReturn(null); |     when(commands.get(eq("Account3::" + uuid))).thenReturn(null); | ||||||
|     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); |     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> retrieved       = accountsManager.get(uuid); |     Optional<Account> retrieved       = accountsManager.get(uuid); | ||||||
| 
 | 
 | ||||||
|     assertTrue(retrieved.isPresent()); |     assertTrue(retrieved.isPresent()); | ||||||
|  | @ -170,6 +174,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -179,7 +184,7 @@ public class AccountsManagerTest { | ||||||
|     when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!")); |     when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!")); | ||||||
|     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); |     when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> retrieved       = accountsManager.get("+14152222222"); |     Optional<Account> retrieved       = accountsManager.get("+14152222222"); | ||||||
| 
 | 
 | ||||||
|     assertTrue(retrieved.isPresent()); |     assertTrue(retrieved.isPresent()); | ||||||
|  | @ -202,6 +207,7 @@ public class AccountsManagerTest { | ||||||
|     DirectoryManager                             directoryManager = mock(DirectoryManager.class); |     DirectoryManager                             directoryManager = mock(DirectoryManager.class); | ||||||
|     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); |     DirectoryQueue                               directoryQueue   = mock(DirectoryQueue.class); | ||||||
|     Keys                                         keys             = mock(Keys.class); |     Keys                                         keys             = mock(Keys.class); | ||||||
|  |     KeysDynamoDb                                 keysDynamoDb     = mock(KeysDynamoDb.class); | ||||||
|     MessagesManager                              messagesManager  = mock(MessagesManager.class); |     MessagesManager                              messagesManager  = mock(MessagesManager.class); | ||||||
|     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); |     UsernamesManager                             usernamesManager = mock(UsernamesManager.class); | ||||||
|     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); |     ProfilesManager                              profilesManager  = mock(ProfilesManager.class); | ||||||
|  | @ -211,7 +217,7 @@ public class AccountsManagerTest { | ||||||
|     when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); |     when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); | ||||||
|     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); |     when(accounts.get(eq(uuid))).thenReturn(Optional.of(account)); | ||||||
| 
 | 
 | ||||||
|     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, messagesManager, usernamesManager, profilesManager); |     AccountsManager   accountsManager = new AccountsManager(accounts, directoryManager, cacheCluster, directoryQueue, keys, keysDynamoDb, messagesManager, usernamesManager, profilesManager); | ||||||
|     Optional<Account> retrieved       = accountsManager.get(uuid); |     Optional<Account> retrieved       = accountsManager.get(uuid); | ||||||
| 
 | 
 | ||||||
|     assertTrue(retrieved.isPresent()); |     assertTrue(retrieved.isPresent()); | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio | ||||||
| import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | import org.whispersystems.textsecuregcm.configuration.AccountsDatabaseConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; | import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; | ||||||
| import org.whispersystems.textsecuregcm.entities.PreKey; | import org.whispersystems.textsecuregcm.entities.PreKey; | ||||||
|  | import org.whispersystems.textsecuregcm.storage.Account; | ||||||
| import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; | ||||||
| import org.whispersystems.textsecuregcm.storage.KeyRecord; | import org.whispersystems.textsecuregcm.storage.KeyRecord; | ||||||
| import org.whispersystems.textsecuregcm.storage.Keys; | import org.whispersystems.textsecuregcm.storage.Keys; | ||||||
|  | @ -41,12 +42,13 @@ import static org.mockito.Mockito.doThrow; | ||||||
| import static org.mockito.Mockito.mock; | import static org.mockito.Mockito.mock; | ||||||
| import static org.mockito.Mockito.when; | import static org.mockito.Mockito.when; | ||||||
| 
 | 
 | ||||||
| @Ignore |  | ||||||
| public class KeysTest { | public class KeysTest { | ||||||
| 
 | 
 | ||||||
|   @Rule |   @Rule | ||||||
|   public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); |   public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); | ||||||
| 
 | 
 | ||||||
|  |   private Account firstAccount; | ||||||
|  |   private Account secondAccount; | ||||||
|   private Keys    keys; |   private Keys    keys; | ||||||
| 
 | 
 | ||||||
|   @Before |   @Before | ||||||
|  | @ -56,6 +58,12 @@ public class KeysTest { | ||||||
|                                                                             new CircuitBreakerConfiguration()); |                                                                             new CircuitBreakerConfiguration()); | ||||||
| 
 | 
 | ||||||
|     this.keys = new Keys(faultTolerantDatabase, new RetryConfiguration()); |     this.keys = new Keys(faultTolerantDatabase, new RetryConfiguration()); | ||||||
|  | 
 | ||||||
|  |     this.firstAccount  = mock(Account.class); | ||||||
|  |     this.secondAccount = mock(Account.class); | ||||||
|  | 
 | ||||||
|  |     when(firstAccount.getNumber()).thenReturn("+14152222222"); | ||||||
|  |     when(secondAccount.getNumber()).thenReturn("+14151111111"); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -79,18 +87,18 @@ public class KeysTest { | ||||||
|       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); |       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); |     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys); |     keys.store(secondAccount, 1, oldAnotherDeviceOnePrKeys); | ||||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); |     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); |     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id"); |     PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id"); | ||||||
|     verifyStoredState(statement, "+14152222222", 1); |     verifyStoredState(statement, firstAccount, 1); | ||||||
|     verifyStoredState(statement, "+14152222222", 2); |     verifyStoredState(statement, firstAccount, 2); | ||||||
|     verifyStoredState(statement, "+14151111111", 1); |     verifyStoredState(statement, secondAccount, 1); | ||||||
|     verifyStoredState(statement, "+14151111111", 2); |     verifyStoredState(statement, secondAccount, 2); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  | @ -102,11 +110,12 @@ public class KeysTest { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   @Ignore | ||||||
|   @Test |   @Test | ||||||
|   public void testGetForDevice() { |   public void testGetForDevice() { | ||||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); |     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||||
|  | @ -125,45 +134,46 @@ public class KeysTest { | ||||||
|       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); |       anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); |     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); |     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); |     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||||
|     List<KeyRecord> records = keys.get("+14152222222", 1); |     List<KeyRecord> records = keys.take(firstAccount, 1); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(1); |     assertThat(records.size()).isEqualTo(1); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); |     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1"); |     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1"); | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||||
| 
 | 
 | ||||||
|     records = keys.get("+14152222222", 1); |     records = keys.take(firstAccount, 1); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(1); |     assertThat(records.size()).isEqualTo(1); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(2); |     assertThat(records.get(0).getKeyId()).isEqualTo(2); | ||||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2"); |     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2"); | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||||
| 
 | 
 | ||||||
|     records = keys.get("+14152222222", 2); |     records = keys.take(firstAccount, 2); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(1); |     assertThat(records.size()).isEqualTo(1); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); |     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||||
|     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1"); |     assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1"); | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   @Ignore | ||||||
|   @Test |   @Test | ||||||
|   public void testGetForAllDevices() { |   public void testGetForAllDevices() { | ||||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); |     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||||
|  | @ -184,18 +194,18 @@ public class KeysTest { | ||||||
|       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); |       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); |     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); |     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); |     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||||
|     keys.store("+14151111111", 3, anotherDeviceThreePreKeys); |     keys.store(secondAccount, 3, anotherDeviceThreePreKeys); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||||
| 
 | 
 | ||||||
|     List<KeyRecord> records = keys.get("+14152222222"); |     List<KeyRecord> records = keys.take(firstAccount); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(2); |     assertThat(records.size()).isEqualTo(2); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); |     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||||
|  | @ -204,10 +214,10 @@ public class KeysTest { | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue(); | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue(); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(99); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(99); | ||||||
| 
 | 
 | ||||||
|     records = keys.get("+14152222222"); |     records = keys.take(firstAccount); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(2); |     assertThat(records.size()).isEqualTo(2); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(2); |     assertThat(records.get(0).getKeyId()).isEqualTo(2); | ||||||
|  | @ -216,11 +226,11 @@ public class KeysTest { | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue(); | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue(); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(98); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(98); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     records = keys.get("+14151111111"); |     records = keys.take(secondAccount); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.size()).isEqualTo(3); |     assertThat(records.size()).isEqualTo(3); | ||||||
|     assertThat(records.get(0).getKeyId()).isEqualTo(1); |     assertThat(records.get(0).getKeyId()).isEqualTo(1); | ||||||
|  | @ -231,11 +241,12 @@ public class KeysTest { | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue(); | ||||||
|     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue(); |     assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue(); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(99); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(99); | ||||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99); |     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(99); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   @Ignore | ||||||
|   @Test |   @Test | ||||||
|   public void testGetForAllDevicesParallel() throws InterruptedException { |   public void testGetForAllDevicesParallel() throws InterruptedException { | ||||||
|     List<PreKey> deviceOnePreKeys = new LinkedList<>(); |     List<PreKey> deviceOnePreKeys = new LinkedList<>(); | ||||||
|  | @ -246,11 +257,11 @@ public class KeysTest { | ||||||
|       deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); |       deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); |     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||||
| 
 | 
 | ||||||
|     List<Thread> threads = new LinkedList<>(); |     List<Thread> threads = new LinkedList<>(); | ||||||
| 
 | 
 | ||||||
|  | @ -260,7 +271,7 @@ public class KeysTest { | ||||||
|         final int MAX_RETRIES = 5; |         final int MAX_RETRIES = 5; | ||||||
|         for (int retryAttempt = 0; results == null && retryAttempt < MAX_RETRIES; ++retryAttempt) { |         for (int retryAttempt = 0; results == null && retryAttempt < MAX_RETRIES; ++retryAttempt) { | ||||||
|           try { |           try { | ||||||
|             results = keys.get("+14152222222"); |             results = keys.take(firstAccount); | ||||||
|           } catch (UnableToExecuteStatementException e) { |           } catch (UnableToExecuteStatementException e) { | ||||||
|             if (retryAttempt == MAX_RETRIES - 1) { |             if (retryAttempt == MAX_RETRIES - 1) { | ||||||
|               throw e; |               throw e; | ||||||
|  | @ -278,8 +289,8 @@ public class KeysTest { | ||||||
|       thread.join(); |       thread.join(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(80); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(80); | ||||||
|     assertThat(keys.getCount("+14152222222",2)).isEqualTo(80); |     assertThat(keys.getCount(firstAccount,2)).isEqualTo(80); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  | @ -302,32 +313,32 @@ public class KeysTest { | ||||||
|       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); |       anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     keys.store("+14152222222", 1, deviceOnePreKeys); |     keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|     keys.store("+14152222222", 2, deviceTwoPreKeys); |     keys.store(firstAccount, 2, deviceTwoPreKeys); | ||||||
| 
 | 
 | ||||||
|     keys.store("+14151111111", 1, anotherDeviceOnePreKeys); |     keys.store(secondAccount, 1, anotherDeviceOnePreKeys); | ||||||
|     keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); |     keys.store(secondAccount, 2, anotherDeviceTwoPreKeys); | ||||||
|     keys.store("+14151111111", 3, anotherDeviceThreePreKeys); |     keys.store(secondAccount, 3, anotherDeviceThreePreKeys); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); | ||||||
| 
 | 
 | ||||||
|     keys.delete("+14152222222"); |     keys.delete(firstAccount); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.getCount("+14152222222", 1)).isEqualTo(0); |     assertThat(keys.getCount(firstAccount, 1)).isEqualTo(0); | ||||||
|     assertThat(keys.getCount("+14152222222", 2)).isEqualTo(0); |     assertThat(keys.getCount(firstAccount, 2)).isEqualTo(0); | ||||||
|     assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 1)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 2)).isEqualTo(100); | ||||||
|     assertThat(keys.getCount("+14151111111", 3)).isEqualTo(100); |     assertThat(keys.getCount(secondAccount, 3)).isEqualTo(100); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|   public void testEmptyKeyGet() { |   public void testEmptyKeyGet() { | ||||||
|     List<KeyRecord> records = keys.get("+14152222222"); |     List<KeyRecord> records = keys.take(firstAccount); | ||||||
| 
 | 
 | ||||||
|     assertThat(records.isEmpty()).isTrue(); |     assertThat(records.isEmpty()).isTrue(); | ||||||
|   } |   } | ||||||
|  | @ -361,21 +372,21 @@ public class KeysTest { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     try { |     try { | ||||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); |       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|       throw new AssertionError(); |       throw new AssertionError(); | ||||||
|     } catch (TransactionException e) { |     } catch (TransactionException e) { | ||||||
|       // good |       // good | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     try { |     try { | ||||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); |       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|       throw new AssertionError(); |       throw new AssertionError(); | ||||||
|     } catch (TransactionException e) { |     } catch (TransactionException e) { | ||||||
|       // good |       // good | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     try { |     try { | ||||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); |       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|       throw new AssertionError(); |       throw new AssertionError(); | ||||||
|     } catch (CallNotPermittedException e) { |     } catch (CallNotPermittedException e) { | ||||||
|       // good |       // good | ||||||
|  | @ -384,7 +395,7 @@ public class KeysTest { | ||||||
|     Thread.sleep(1100); |     Thread.sleep(1100); | ||||||
| 
 | 
 | ||||||
|     try { |     try { | ||||||
|       keys.store("+14152222222", 1, deviceOnePreKeys); |       keys.store(firstAccount, 1, deviceOnePreKeys); | ||||||
|       throw new AssertionError(); |       throw new AssertionError(); | ||||||
|     } catch (TransactionException e) { |     } catch (TransactionException e) { | ||||||
|       // good |       // good | ||||||
|  | @ -401,7 +412,10 @@ public class KeysTest { | ||||||
|     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); |     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); | ||||||
| 
 | 
 | ||||||
|     // We're happy as long as nothing throws an exception |     // We're happy as long as nothing throws an exception | ||||||
|     keys.store("+18005551234", 1, Collections.emptyList()); |     Account account = mock(Account.class); | ||||||
|  |     when(account.getNumber()).thenReturn("+18005551234"); | ||||||
|  | 
 | ||||||
|  |     keys.store(account, 1, Collections.emptyList()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  | @ -414,12 +428,15 @@ public class KeysTest { | ||||||
| 
 | 
 | ||||||
|     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); |     Keys keys = new Keys(new FaultTolerantDatabase("testBreaker", jdbi, new CircuitBreakerConfiguration()), new RetryConfiguration()); | ||||||
| 
 | 
 | ||||||
|     assertThat(keys.get("+18005551234")).isEqualTo(Collections.emptyList()); |     Account account = mock(Account.class); | ||||||
|     assertThat(keys.get("+18005551234", 1)).isEqualTo(Collections.emptyList()); |     when(account.getNumber()).thenReturn("+18005551234"); | ||||||
|  | 
 | ||||||
|  |     assertThat(keys.take(account)).isEqualTo(Collections.emptyList()); | ||||||
|  |     assertThat(keys.take(account, 1)).isEqualTo(Collections.emptyList()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException { |   private void verifyStoredState(PreparedStatement statement, Account account, int deviceId) throws SQLException { | ||||||
|     statement.setString(1, number); |     statement.setString(1, account.getNumber()); | ||||||
|     statement.setInt(2, deviceId); |     statement.setInt(2, deviceId); | ||||||
| 
 | 
 | ||||||
|     ResultSet resultSet = statement.executeQuery(); |     ResultSet resultSet = statement.executeQuery(); | ||||||
|  | @ -431,7 +448,7 @@ public class KeysTest { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|       assertThat(keyId).isEqualTo(rowCount); |       assertThat(keyId).isEqualTo(rowCount); | ||||||
|       assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount); |       assertThat(publicKey).isEqualTo(account.getNumber() + "Device" + deviceId + "PublicKey" + rowCount); | ||||||
| 
 | 
 | ||||||
|       rowCount++; |       rowCount++; | ||||||
|     } |     } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Jon Chambers
						Jon Chambers