some accounts classes refactorings

This commit is contained in:
Sergey Skrobotov 2022-12-02 11:09:00 -08:00
parent d0e7579f13
commit 9cf2635528
13 changed files with 430 additions and 352 deletions

View File

@ -337,7 +337,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.executorService(name(getClass(), "messageDeletionAsyncExecutor-%d")).maxThreads(16) .executorService(name(getClass(), "messageDeletionAsyncExecutor-%d")).maxThreads(16)
.workQueue(messageDeletionQueue).build(); .workQueue(messageDeletionQueue).build();
Accounts accounts = new Accounts(dynamicConfigurationManager, Accounts accounts = new Accounts(
dynamoDbClient, dynamoDbClient,
dynamoDbAsyncClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getAccounts().getTableName(), config.getDynamoDbTables().getAccounts().getTableName(),

View File

@ -16,7 +16,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@ -25,20 +24,26 @@ import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse; import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest; import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
import javax.annotation.Nonnull;
public abstract class AbstractDynamoDbStore { public abstract class AbstractDynamoDbStore {
private final DynamoDbClient dynamoDbClient; private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high.
private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true"); public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this.
private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false");
private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed")); public static final int RESULT_SET_CHUNK_SIZE = 100;
private final Logger logger = LoggerFactory.getLogger(getClass()); private final Logger logger = LoggerFactory.getLogger(getClass());
private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high. private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true");
public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this.
public static final int RESULT_SET_CHUNK_SIZE = 100; private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false");
private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed"));
private final DynamoDbClient dynamoDbClient;
public AbstractDynamoDbStore(final DynamoDbClient dynamoDbClient) { public AbstractDynamoDbStore(final DynamoDbClient dynamoDbClient) {
this.dynamoDbClient = dynamoDbClient; this.dynamoDbClient = dynamoDbClient;
@ -49,18 +54,15 @@ public abstract class AbstractDynamoDbStore {
} }
protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) { protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) {
AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>(); final AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
batchWriteItemsFirstPass.record( writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome);
() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build())));
int attemptCount = 0; int attemptCount = 0;
while (!outcome.get().unprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { while (!outcome.get().unprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) {
batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder() writeAndStoreOutcome(outcome.get().unprocessedItems(), batchWriteItemsRetryPass, outcome);
.requestItems(outcome.get().unprocessedItems())
.build())));
++attemptCount; ++attemptCount;
} }
if (!outcome.get().unprocessedItems().isEmpty()) { if (!outcome.get().unprocessedItems().isEmpty()) {
int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum(); final int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum();
logger.error( logger.error(
"Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", "Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.",
attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, totalItems); attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, totalItems);
@ -68,19 +70,28 @@ public abstract class AbstractDynamoDbStore {
} }
} }
protected List<Map<String, AttributeValue>> scan(ScanRequest scanRequest, int max) { @Nonnull
protected List<Map<String, AttributeValue>> scan(final ScanRequest scanRequest, final int max) {
return db().scanPaginator(scanRequest) return db().scanPaginator(scanRequest)
.items() .items()
.stream() .stream()
.limit(max) .limit(max)
.collect(Collectors.toList()); .toList();
}
private void writeAndStoreOutcome(
final Map<String, List<WriteRequest>> items,
final Timer timer,
final AtomicReference<BatchWriteItemResponse> outcome) {
timer.record(
() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build()))
);
} }
static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) { static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) {
final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE); final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE);
for (T item : items) { for (final T item : items) {
batch.add(item); batch.add(item);
if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) { if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) {

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import static java.util.Objects.requireNonNull;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
@ -18,7 +19,6 @@ import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -29,12 +29,14 @@ import java.util.UUID;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
@ -56,8 +58,31 @@ import software.amazon.awssdk.services.dynamodb.model.Update;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.CompletableFutureUtils;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class Accounts extends AbstractDynamoDbStore { public class Accounts extends AbstractDynamoDbStore {
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
private static final byte RESERVED_USERNAME_HASH_VERSION = 1;
private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create"));
private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber"));
private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername"));
private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername"));
private static final Timer CLEAR_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "clearUsername"));
private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber"));
private static final Timer GET_BY_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "getByUsername"));
private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni"));
private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid"));
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed";
private static final String TRANSACTION_CONFLICT = "TransactionConflict";
// uuid, primary key // uuid, primary key
static final String KEY_ACCOUNT_UUID = "U"; static final String KEY_ACCOUNT_UUID = "U";
// uuid, attribute on account table, primary key for PNI table // uuid, attribute on account table, primary key for PNI table
@ -78,45 +103,32 @@ public class Accounts extends AbstractDynamoDbStore {
static final String ATTR_TTL = "TTL"; static final String ATTR_TTL = "TTL";
private final Clock clock; private final Clock clock;
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient; private final DynamoDbAsyncClient asyncClient;
private final String phoneNumberConstraintTableName; private final String phoneNumberConstraintTableName;
private final String phoneNumberIdentifierConstraintTableName; private final String phoneNumberIdentifierConstraintTableName;
private final String usernamesConstraintTableName; private final String usernamesConstraintTableName;
private final String accountsTableName; private final String accountsTableName;
private final int scanPageSize; private final int scanPageSize;
private static final byte RESERVED_USERNAME_HASH_VERSION = 1;
private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create"));
private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber"));
private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername"));
private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername"));
private static final Timer CLEAR_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "clearUsername"));
private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber"));
private static final Timer GET_BY_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "getByUsername"));
private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni"));
private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid"));
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
@VisibleForTesting @VisibleForTesting
public Accounts( public Accounts(
final Clock clock, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final DynamoDbClient client,
DynamoDbClient client, DynamoDbAsyncClient asyncClient, final DynamoDbAsyncClient asyncClient,
String accountsTableName, String phoneNumberConstraintTableName, final String accountsTableName,
String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, final String phoneNumberConstraintTableName,
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final int scanPageSize) { final int scanPageSize) {
super(client); super(client);
this.clock = clock; this.clock = clock;
this.client = client;
this.asyncClient = asyncClient; this.asyncClient = asyncClient;
this.phoneNumberConstraintTableName = phoneNumberConstraintTableName; this.phoneNumberConstraintTableName = phoneNumberConstraintTableName;
this.phoneNumberIdentifierConstraintTableName = phoneNumberIdentifierConstraintTableName; this.phoneNumberIdentifierConstraintTableName = phoneNumberIdentifierConstraintTableName;
@ -125,105 +137,61 @@ public class Accounts extends AbstractDynamoDbStore {
this.scanPageSize = scanPageSize; this.scanPageSize = scanPageSize;
} }
public Accounts(final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, public Accounts(
DynamoDbClient client, DynamoDbAsyncClient asyncClient, final DynamoDbClient client,
String accountsTableName, String phoneNumberConstraintTableName, final DynamoDbAsyncClient asyncClient,
String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, final String accountsTableName,
final String phoneNumberConstraintTableName,
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final int scanPageSize) { final int scanPageSize) {
this(Clock.systemUTC(), dynamicConfigurationManager, client, asyncClient, accountsTableName, this(Clock.systemUTC(), client, asyncClient, accountsTableName,
phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName, phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName,
scanPageSize); scanPageSize);
} }
public boolean create(Account account) { public boolean create(final Account account) {
return CREATE_TIMER.record(() -> { return CREATE_TIMER.record(() -> {
try { try {
TransactWriteItem phoneNumberConstraintPut = TransactWriteItem.builder() final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid());
.put( final AttributeValue numberAttr = AttributeValues.fromString(account.getNumber());
Put.builder() final AttributeValue pniUuidAttr = AttributeValues.fromUUID(account.getPhoneNumberIdentifier());
.tableName(phoneNumberConstraintTableName)
.item(Map.of(
ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.conditionExpression(
"attribute_not_exists(#number) OR (attribute_exists(#number) AND #uuid = :uuid)")
.expressionAttributeNames(
Map.of("#uuid", KEY_ACCOUNT_UUID,
"#number", ATTR_ACCOUNT_E164))
.expressionAttributeValues(
Map.of(":uuid", AttributeValues.fromUUID(account.getUuid())))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build();
TransactWriteItem phoneNumberIdentifierConstraintPut = TransactWriteItem.builder() final TransactWriteItem phoneNumberConstraintPut = buildConstraintTablePutIfAbsent(
.put( phoneNumberConstraintTableName, uuidAttr, ATTR_ACCOUNT_E164, numberAttr);
Put.builder()
.tableName(phoneNumberIdentifierConstraintTableName)
.item(Map.of(
ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier()),
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.conditionExpression(
"attribute_not_exists(#pni) OR (attribute_exists(#pni) AND #uuid = :uuid)")
.expressionAttributeNames(
Map.of("#uuid", KEY_ACCOUNT_UUID,
"#pni", ATTR_PNI_UUID))
.expressionAttributeValues(
Map.of(":uuid", AttributeValues.fromUUID(account.getUuid())))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build();
final Map<String, AttributeValue> item = new HashMap<>(Map.of( final TransactWriteItem phoneNumberIdentifierConstraintPut = buildConstraintTablePutIfAbsent(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()), phoneNumberIdentifierConstraintTableName, uuidAttr, ATTR_PNI_UUID, pniUuidAttr);
ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier()),
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory())));
// Add the UAK if it's in the account final TransactWriteItem accountPut = buildAccountPut(account, uuidAttr, numberAttr, pniUuidAttr);
account.getUnidentifiedAccessKey()
.map(AttributeValues::fromByteArray)
.ifPresent(uak -> item.put(ATTR_UAK, uak));
TransactWriteItem accountPut = TransactWriteItem.builder()
.put(Put.builder()
.conditionExpression("attribute_not_exists(#number) OR #number = :number")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164))
.expressionAttributeValues(Map.of(":number", AttributeValues.fromString(account.getNumber())))
.tableName(accountsTableName)
.item(item)
.build())
.build();
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut) .transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut)
.build(); .build();
try { try {
client.transactWriteItems(request); db().transactWriteItems(request);
} catch (TransactionCanceledException e) { } catch (final TransactionCanceledException e) {
final CancellationReason accountCancellationReason = e.cancellationReasons().get(2); final CancellationReason accountCancellationReason = e.cancellationReasons().get(2);
if ("ConditionalCheckFailed".equals(accountCancellationReason.code())) { if (conditionalCheckFailed(accountCancellationReason)) {
throw new IllegalArgumentException("account identifier present with different phone number"); throw new IllegalArgumentException("account identifier present with different phone number");
} }
final CancellationReason phoneNumberConstraintCancellationReason = e.cancellationReasons().get(0); final CancellationReason phoneNumberConstraintCancellationReason = e.cancellationReasons().get(0);
final CancellationReason phoneNumberIdentifierConstraintCancellationReason = e.cancellationReasons().get(1); final CancellationReason phoneNumberIdentifierConstraintCancellationReason = e.cancellationReasons().get(1);
if ("ConditionalCheckFailed".equals(phoneNumberConstraintCancellationReason.code()) || if (conditionalCheckFailed(phoneNumberConstraintCancellationReason)
"ConditionalCheckFailed".equals(phoneNumberIdentifierConstraintCancellationReason.code())) { || conditionalCheckFailed(phoneNumberIdentifierConstraintCancellationReason)) {
// In theory, both reasons should trip in tandem and either should give us the information we need. Even so, // In theory, both reasons should trip in tandem and either should give us the information we need. Even so,
// we'll be cautious here and make sure we're choosing a condition check that really failed. // we'll be cautious here and make sure we're choosing a condition check that really failed.
final CancellationReason reason = "ConditionalCheckFailed".equals(phoneNumberConstraintCancellationReason.code()) ? final CancellationReason reason = conditionalCheckFailed(phoneNumberConstraintCancellationReason)
phoneNumberConstraintCancellationReason : phoneNumberIdentifierConstraintCancellationReason; ? phoneNumberConstraintCancellationReason
: phoneNumberIdentifierConstraintCancellationReason;
ByteBuffer actualAccountUuid = reason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer(); final ByteBuffer actualAccountUuid = reason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer();
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid)); account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow(); final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
@ -235,7 +203,7 @@ public class Accounts extends AbstractDynamoDbStore {
return false; return false;
} }
if ("TransactionConflict".equals(accountCancellationReason.code())) { if (TRANSACTION_CONFLICT.equals(accountCancellationReason.code())) {
// this should only happen if two clients manage to make concurrent create() calls // this should only happen if two clients manage to make concurrent create() calls
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} }
@ -243,7 +211,7 @@ public class Accounts extends AbstractDynamoDbStore {
// this shouldn't happen // this shouldn't happen
throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e)); throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e));
} }
} catch (JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
@ -275,62 +243,34 @@ public class Accounts extends AbstractDynamoDbStore {
try { try {
final List<TransactWriteItem> writeItems = new ArrayList<>(); final List<TransactWriteItem> writeItems = new ArrayList<>();
final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid());
final AttributeValue numberAttr = AttributeValues.fromString(number);
final AttributeValue pniAttr = AttributeValues.fromUUID(phoneNumberIdentifier);
writeItems.add(TransactWriteItem.builder() writeItems.add(buildDelete(phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, originalNumber));
.delete(Delete.builder() writeItems.add(buildConstraintTablePut(phoneNumberConstraintTableName, uuidAttr, ATTR_ACCOUNT_E164, numberAttr));
.tableName(phoneNumberConstraintTableName) writeItems.add(buildDelete(phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, originalPni));
.key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(originalNumber))) writeItems.add(buildConstraintTablePut(phoneNumberIdentifierConstraintTableName, uuidAttr, ATTR_PNI_UUID, pniAttr));
.build())
.build());
writeItems.add(TransactWriteItem.builder()
.put(Put.builder()
.tableName(phoneNumberConstraintTableName)
.item(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()),
ATTR_ACCOUNT_E164, AttributeValues.fromString(number)))
.conditionExpression("attribute_not_exists(#number)")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build());
writeItems.add(TransactWriteItem.builder()
.delete(Delete.builder()
.tableName(phoneNumberIdentifierConstraintTableName)
.key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(originalPni)))
.build())
.build());
writeItems.add(TransactWriteItem.builder()
.put(Put.builder()
.tableName(phoneNumberIdentifierConstraintTableName)
.item(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()),
ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier)))
.conditionExpression("attribute_not_exists(#pni)")
.expressionAttributeNames(Map.of("#pni", ATTR_PNI_UUID))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build());
writeItems.add( writeItems.add(
TransactWriteItem.builder() TransactWriteItem.builder()
.update(Update.builder() .update(Update.builder()
.tableName(accountsTableName) .tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) .key(Map.of(KEY_ACCOUNT_UUID, uuidAttr))
.updateExpression("SET #data = :data, #number = :number, #pni = :pni, #cds = :cds ADD #version :version_increment") .updateExpression(
.conditionExpression("attribute_exists(#number) AND #version = :version") "SET #data = :data, #number = :number, #pni = :pni, #cds = :cds ADD #version :version_increment")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164, .conditionExpression(
"attribute_exists(#number) AND #version = :version")
.expressionAttributeNames(Map.of(
"#number", ATTR_ACCOUNT_E164,
"#data", ATTR_ACCOUNT_DATA, "#data", ATTR_ACCOUNT_DATA,
"#cds", ATTR_CANONICALLY_DISCOVERABLE, "#cds", ATTR_CANONICALLY_DISCOVERABLE,
"#pni", ATTR_PNI_UUID, "#pni", ATTR_PNI_UUID,
"#version", ATTR_VERSION)) "#version", ATTR_VERSION))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":number", numberAttr,
":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), ":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
":number", AttributeValues.fromString(number),
":pni", AttributeValues.fromUUID(phoneNumberIdentifier),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":pni", pniAttr,
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))) ":version_increment", AttributeValues.fromInt(1)))
.build()) .build())
@ -340,7 +280,7 @@ public class Accounts extends AbstractDynamoDbStore {
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
client.transactWriteItems(request); db().transactWriteItems(request);
account.setVersion(account.getVersion() + 1); account.setVersion(account.getVersion() + 1);
succeeded = true; succeeded = true;
@ -354,21 +294,6 @@ public class Accounts extends AbstractDynamoDbStore {
}); });
} }
public static byte[] reservedUsernameHash(final UUID accountId, final String reservedUsername) {
final MessageDigest sha256;
try {
sha256 = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
final ByteBuffer byteBuffer = ByteBuffer.allocate(32 + 1);
sha256.update(reservedUsername.getBytes(StandardCharsets.UTF_8));
sha256.update(UUIDUtil.toBytes(accountId));
byteBuffer.put(RESERVED_USERNAME_HASH_VERSION);
byteBuffer.put(sha256.digest());
return byteBuffer.array();
}
/** /**
* Reserve a username under a token * Reserve a username under a token
* *
@ -386,7 +311,7 @@ public class Accounts extends AbstractDynamoDbStore {
boolean succeeded = false; boolean succeeded = false;
long expirationTime = clock.instant().plus(ttl).getEpochSecond(); final long expirationTime = clock.instant().plus(ttl).getEpochSecond();
final UUID reservationToken = UUID.randomUUID(); final UUID reservationToken = UUID.randomUUID();
try { try {
@ -425,14 +350,14 @@ public class Accounts extends AbstractDynamoDbStore {
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
client.transactWriteItems(request); db().transactWriteItems(request);
account.setVersion(account.getVersion() + 1); account.setVersion(account.getVersion() + 1);
succeeded = true; succeeded = true;
} catch (final JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} catch (final TransactionCanceledException e) { } catch (final TransactionCanceledException e) {
if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch("ConditionalCheckFailed"::equals)) { if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) {
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} }
throw e; throw e;
@ -520,25 +445,21 @@ public class Accounts extends AbstractDynamoDbStore {
.build()) .build())
.build()); .build());
maybeOriginalUsername.ifPresent(originalUsername -> writeItems.add(TransactWriteItem.builder() maybeOriginalUsername.ifPresent(originalUsername -> writeItems.add(
.delete(Delete.builder() buildDelete(usernamesConstraintTableName, ATTR_USERNAME, originalUsername)));
.tableName(usernamesConstraintTableName)
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(originalUsername)))
.build())
.build()));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
client.transactWriteItems(request); db().transactWriteItems(request);
account.setVersion(account.getVersion() + 1); account.setVersion(account.getVersion() + 1);
succeeded = true; succeeded = true;
} catch (final JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} catch (final TransactionCanceledException e) { } catch (final TransactionCanceledException e) {
if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch("ConditionalCheckFailed"::equals)) { if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) {
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} }
throw e; throw e;
@ -551,7 +472,7 @@ public class Accounts extends AbstractDynamoDbStore {
} }
} }
public void clearUsername(Account account) { public void clearUsername(final Account account) {
account.getUsername().ifPresent(username -> { account.getUsername().ifPresent(username -> {
CLEAR_USERNAME_TIMER.record(() -> { CLEAR_USERNAME_TIMER.record(() -> {
account.setUsername(null); account.setUsername(null);
@ -578,25 +499,20 @@ public class Accounts extends AbstractDynamoDbStore {
.build()) .build())
.build()); .build());
writeItems.add(TransactWriteItem.builder() writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME, username));
.delete(Delete.builder()
.tableName(usernamesConstraintTableName)
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username)))
.build())
.build());
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
client.transactWriteItems(request); db().transactWriteItems(request);
account.setVersion(account.getVersion() + 1); account.setVersion(account.getVersion() + 1);
succeeded = true; succeeded = true;
} catch (final JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} catch (final TransactionCanceledException e) { } catch (final TransactionCanceledException e) {
if ("ConditionalCheckFailed".equals(e.cancellationReasons().get(0).code())) { if (conditionalCheckFailed(e.cancellationReasons().get(0))) {
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} }
@ -610,27 +526,18 @@ public class Accounts extends AbstractDynamoDbStore {
}); });
} }
/** @Nonnull
* Extract the cause from a CompletionException public CompletionStage<Void> updateAsync(final Account account) {
*/
private static Throwable unwrap(Throwable throwable) {
while (throwable instanceof CompletionException e && throwable.getCause() != null) {
throwable = e.getCause();
}
return throwable;
}
public CompletionStage<Void> updateAsync(Account account) {
return record(UPDATE_TIMER, () -> { return record(UPDATE_TIMER, () -> {
final UpdateItemRequest updateItemRequest; final UpdateItemRequest updateItemRequest;
try { try {
// username, e164, and pni cannot be modified through this method // username, e164, and pni cannot be modified through this method
Map<String, String> attrNames = new HashMap<>(Map.of( final Map<String, String> attrNames = new HashMap<>(Map.of(
"#number", ATTR_ACCOUNT_E164, "#number", ATTR_ACCOUNT_E164,
"#data", ATTR_ACCOUNT_DATA, "#data", ATTR_ACCOUNT_DATA,
"#cds", ATTR_CANONICALLY_DISCOVERABLE, "#cds", ATTR_CANONICALLY_DISCOVERABLE,
"#version", ATTR_VERSION)); "#version", ATTR_VERSION));
Map<String, AttributeValue> attrValues = new HashMap<>(Map.of( final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), ":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
@ -654,7 +561,7 @@ public class Accounts extends AbstractDynamoDbStore {
.expressionAttributeNames(attrNames) .expressionAttributeNames(attrNames)
.expressionAttributeValues(attrValues) .expressionAttributeValues(attrValues)
.build(); .build();
} catch (JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
@ -664,7 +571,7 @@ public class Accounts extends AbstractDynamoDbStore {
return (Void) null; return (Void) null;
}) })
.exceptionally(throwable -> { .exceptionally(throwable -> {
final Throwable unwrapped = unwrap(throwable); final Throwable unwrapped = ExceptionUtils.unwrap(throwable);
if (unwrapped instanceof TransactionConflictException) { if (unwrapped instanceof TransactionConflictException) {
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} else if (unwrapped instanceof ConditionalCheckFailedException e) { } else if (unwrapped instanceof ConditionalCheckFailedException e) {
@ -679,12 +586,12 @@ public class Accounts extends AbstractDynamoDbStore {
}); });
} }
public void update(Account account) throws ContestedOptimisticLockException { public void update(final Account account) throws ContestedOptimisticLockException {
try { try {
this.updateAsync(account).toCompletableFuture().join(); updateAsync(account).toCompletableFuture().join();
} catch (CompletionException e) { } catch (final CompletionException e) {
// unwrap CompletionExceptions, throw as long is it's unchecked // unwrap CompletionExceptions, throw as long is it's unchecked
Throwables.throwIfUnchecked(unwrap(e)); Throwables.throwIfUnchecked(ExceptionUtils.unwrap(e));
// if we otherwise somehow got a wrapped checked exception, // if we otherwise somehow got a wrapped checked exception,
// rethrow the checked exception wrapped by the original CompletionException // rethrow the checked exception wrapped by the original CompletionException
@ -698,15 +605,14 @@ public class Accounts extends AbstractDynamoDbStore {
} }
public boolean usernameAvailable(final Optional<UUID> reservationToken, final String username) { public boolean usernameAvailable(final Optional<UUID> reservationToken, final String username) {
final GetItemResponse response = client.getItem(GetItemRequest.builder() final Optional<Map<String, AttributeValue>> usernameItem = itemByKey(
.tableName(usernamesConstraintTableName) usernamesConstraintTableName, ATTR_USERNAME, AttributeValues.fromString(username));
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username)))
.build()); if (usernameItem.isEmpty()) {
if (!response.hasItem()) {
// username is free // username is free
return true; return true;
} }
final Map<String, AttributeValue> item = response.item(); final Map<String, AttributeValue> item = usernameItem.get();
if (AttributeValues.getLong(item, ATTR_TTL, Long.MAX_VALUE) < clock.instant().getEpochSecond()) { if (AttributeValues.getLong(item, ATTR_TTL, Long.MAX_VALUE) < clock.instant().getEpochSecond()) {
// username was reserved, but has expired // username was reserved, but has expired
@ -719,112 +625,56 @@ public class Accounts extends AbstractDynamoDbStore {
.orElse(false); .orElse(false);
} }
public Optional<Account> getByE164(String number) { @Nonnull
return GET_BY_NUMBER_TIMER.record(() -> { public Optional<Account> getByE164(final String number) {
return getByIndirectLookup(
final GetItemResponse response = client.getItem(GetItemRequest.builder() GET_BY_NUMBER_TIMER, phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, AttributeValues.fromString(number));
.tableName(phoneNumberConstraintTableName)
.key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(number)))
.build());
return Optional.ofNullable(response.item())
.map(item -> item.get(KEY_ACCOUNT_UUID))
.map(this::accountByUuid)
.map(Accounts::fromItem);
});
}
public Optional<Account> getByUsername(final String username) {
return GET_BY_USERNAME_TIMER.record(() -> {
final GetItemResponse response = client.getItem(GetItemRequest.builder()
.tableName(usernamesConstraintTableName)
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username)))
.build());
return Optional.ofNullable(response.item())
// ignore items with a ttl (reservations)
.filter(item -> !item.containsKey(ATTR_TTL))
.map(item -> item.get(KEY_ACCOUNT_UUID))
.map(this::accountByUuid)
.map(Accounts::fromItem);
});
} }
@Nonnull
public Optional<Account> getByPhoneNumberIdentifier(final UUID phoneNumberIdentifier) { public Optional<Account> getByPhoneNumberIdentifier(final UUID phoneNumberIdentifier) {
return GET_BY_PNI_TIMER.record(() -> { return getByIndirectLookup(
GET_BY_PNI_TIMER, phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier));
final GetItemResponse response = client.getItem(GetItemRequest.builder()
.tableName(phoneNumberIdentifierConstraintTableName)
.key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier)))
.build());
return Optional.ofNullable(response.item())
.map(item -> item.get(KEY_ACCOUNT_UUID))
.map(this::accountByUuid)
.map(Accounts::fromItem);
});
} }
private Map<String, AttributeValue> accountByUuid(AttributeValue uuid) { @Nonnull
GetItemResponse r = client.getItem(GetItemRequest.builder() public Optional<Account> getByUsername(final String username) {
.tableName(accountsTableName) return getByIndirectLookup(
.key(Map.of(KEY_ACCOUNT_UUID, uuid)) GET_BY_USERNAME_TIMER,
.consistentRead(true) usernamesConstraintTableName,
.build()); ATTR_USERNAME,
return r.item().isEmpty() ? null : r.item(); AttributeValues.fromString(username),
item -> !item.containsKey(ATTR_TTL) // ignore items with a ttl (reservations)
);
} }
public Optional<Account> getByAccountIdentifier(UUID uuid) { @Nonnull
return GET_BY_UUID_TIMER.record(() -> public Optional<Account> getByAccountIdentifier(final UUID uuid) {
Optional.ofNullable(accountByUuid(AttributeValues.fromUUID(uuid))) return requireNonNull(GET_BY_UUID_TIMER.record(() ->
.map(Accounts::fromItem)); itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid))
.map(Accounts::fromItem)));
} }
public void delete(UUID uuid) { public void delete(final UUID uuid) {
DELETE_TIMER.record(() -> { DELETE_TIMER.record(() -> getByAccountIdentifier(uuid).ifPresent(account -> {
getByAccountIdentifier(uuid).ifPresent(account -> { final List<TransactWriteItem> transactWriteItems = new ArrayList<>(List.of(
buildDelete(phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, account.getNumber()),
buildDelete(accountsTableName, KEY_ACCOUNT_UUID, uuid),
buildDelete(phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, account.getPhoneNumberIdentifier())
));
TransactWriteItem phoneNumberDelete = TransactWriteItem.builder() account.getUsername().ifPresent(username -> transactWriteItems.add(
.delete(Delete.builder() buildDelete(usernamesConstraintTableName, ATTR_USERNAME, username)));
.tableName(phoneNumberConstraintTableName)
.key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber())))
.build())
.build();
TransactWriteItem accountDelete = TransactWriteItem.builder() final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.delete(Delete.builder() .transactItems(transactWriteItems).build();
.tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)))
.build())
.build();
final List<TransactWriteItem> transactWriteItems = new ArrayList<>(List.of(phoneNumberDelete, accountDelete)); db().transactWriteItems(request);
}));
transactWriteItems.add(TransactWriteItem.builder()
.delete(Delete.builder()
.tableName(phoneNumberIdentifierConstraintTableName)
.key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier())))
.build())
.build());
account.getUsername().ifPresent(username -> transactWriteItems.add(TransactWriteItem.builder()
.delete(Delete.builder()
.tableName(usernamesConstraintTableName)
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username)))
.build())
.build()));
TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(transactWriteItems).build();
client.transactWriteItems(request);
});
});
} }
@Nonnull
public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) { public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) {
final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder()
.limit(scanPageSize) .limit(scanPageSize)
@ -833,6 +683,7 @@ public class Accounts extends AbstractDynamoDbStore {
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_OFFSET_TIMER); return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_OFFSET_TIMER);
} }
@Nonnull
public AccountCrawlChunk getAllFromStart(final int maxCount) { public AccountCrawlChunk getAllFromStart(final int maxCount) {
final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder()
.limit(scanPageSize); .limit(scanPageSize);
@ -840,34 +691,185 @@ public class Accounts extends AbstractDynamoDbStore {
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER); return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
} }
private static <T> CompletionStage<T> record(final Timer timer, Supplier<CompletionStage<T>> toRecord) { @Nonnull
final Instant start = Instant.now(); private Optional<Account> getByIndirectLookup(
return toRecord.get().whenComplete((ignoreT, ignoreE) -> timer.record(Duration.between(start, Instant.now()))); final Timer timer,
final String tableName,
final String keyName,
final AttributeValue keyValue) {
return getByIndirectLookup(timer, tableName, keyName, keyValue, i -> true);
} }
@Nonnull
private Optional<Account> getByIndirectLookup(
final Timer timer,
final String tableName,
final String keyName,
final AttributeValue keyValue,
final Predicate<? super Map<String, AttributeValue>> predicate) {
return requireNonNull(timer.record(() -> itemByKey(tableName, keyName, keyValue)
.filter(predicate)
.map(item -> item.get(KEY_ACCOUNT_UUID))
.flatMap(uuid -> itemByKey(accountsTableName, KEY_ACCOUNT_UUID, uuid))
.map(Accounts::fromItem)));
}
@Nonnull
private Optional<Map<String, AttributeValue>> itemByKey(final String table, final String keyName, final AttributeValue keyValue) {
final GetItemResponse response = db().getItem(GetItemRequest.builder()
.tableName(table)
.key(Map.of(keyName, keyValue))
.consistentRead(true)
.build());
return Optional.ofNullable(response.item()).filter(m -> !m.isEmpty());
}
@Nonnull
private TransactWriteItem buildAccountPut(
final Account account,
final AttributeValue uuidAttr,
final AttributeValue numberAttr,
final AttributeValue pniUuidAttr) throws JsonProcessingException {
final Map<String, AttributeValue> item = new HashMap<>(Map.of(
KEY_ACCOUNT_UUID, uuidAttr,
ATTR_ACCOUNT_E164, numberAttr,
ATTR_PNI_UUID, pniUuidAttr,
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory())));
// Add the UAK if it's in the account
account.getUnidentifiedAccessKey()
.map(AttributeValues::fromByteArray)
.ifPresent(uak -> item.put(ATTR_UAK, uak));
return TransactWriteItem.builder()
.put(Put.builder()
.conditionExpression("attribute_not_exists(#number) OR #number = :number")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164))
.expressionAttributeValues(Map.of(":number", numberAttr))
.tableName(accountsTableName)
.item(item)
.build())
.build();
}
@Nonnull
private static TransactWriteItem buildConstraintTablePutIfAbsent(
final String tableName,
final AttributeValue uuidAttr,
final String keyName,
final AttributeValue keyValue
) {
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(Map.of(
keyName, keyValue,
KEY_ACCOUNT_UUID, uuidAttr))
.conditionExpression(
"attribute_not_exists(#key) OR #uuid = :uuid")
.expressionAttributeNames(Map.of(
"#key", keyName,
"#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", uuidAttr))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build();
}
@Nonnull
private static TransactWriteItem buildConstraintTablePut(
final String tableName,
final AttributeValue uuidAttr,
final String keyName,
final AttributeValue keyValue) {
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(Map.of(
keyName, keyValue,
KEY_ACCOUNT_UUID, uuidAttr))
.conditionExpression(
"attribute_not_exists(#key)")
.expressionAttributeNames(Map.of(
"#key", keyName))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build())
.build();
}
@Nonnull
private static TransactWriteItem buildDelete(final String tableName, final String keyName, final String keyValue) {
return buildDelete(tableName, keyName, AttributeValues.fromString(keyValue));
}
@Nonnull
private static TransactWriteItem buildDelete(final String tableName, final String keyName, final UUID keyValue) {
return buildDelete(tableName, keyName, AttributeValues.fromUUID(keyValue));
}
@Nonnull
private static TransactWriteItem buildDelete(final String tableName, final String keyName, final AttributeValue keyValue) {
return TransactWriteItem.builder()
.delete(Delete.builder()
.tableName(tableName)
.key(Map.of(keyName, keyValue))
.build())
.build();
}
@Nonnull
private static <T> CompletionStage<T> record(final Timer timer, final Supplier<CompletionStage<T>> toRecord) {
final Timer.Sample sample = Timer.start();
return toRecord.get().whenComplete((ignoreT, ignoreE) -> sample.stop(timer));
}
@Nonnull
private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) { private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
scanRequestBuilder.tableName(accountsTableName); scanRequestBuilder.tableName(accountsTableName);
final List<Map<String, AttributeValue>> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)); final List<Map<String, AttributeValue>> items = requireNonNull(timer.record(() -> scan(scanRequestBuilder.build(), maxCount)));
final List<Account> accounts = items.stream().map(Accounts::fromItem).toList(); final List<Account> accounts = items.stream().map(Accounts::fromItem).toList();
return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null); return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null);
} }
@Nonnull
private static String extractCancellationReasonCodes(final TransactionCanceledException exception) { private static String extractCancellationReasonCodes(final TransactionCanceledException exception) {
return exception.cancellationReasons().stream() return exception.cancellationReasons().stream()
.map(CancellationReason::code) .map(CancellationReason::code)
.collect(Collectors.joining(", ")); .collect(Collectors.joining(", "));
} }
@Nonnull
public static byte[] reservedUsernameHash(final UUID accountId, final String reservedUsername) {
final MessageDigest sha256;
try {
sha256 = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
final ByteBuffer byteBuffer = ByteBuffer.allocate(32 + 1);
sha256.update(reservedUsername.getBytes(StandardCharsets.UTF_8));
sha256.update(UUIDUtil.toBytes(accountId));
byteBuffer.put(RESERVED_USERNAME_HASH_VERSION);
byteBuffer.put(sha256.digest());
return byteBuffer.array();
}
@VisibleForTesting @VisibleForTesting
static Account fromItem(Map<String, AttributeValue> item) { @Nonnull
if (!item.containsKey(ATTR_ACCOUNT_DATA) || static Account fromItem(final Map<String, AttributeValue> item) {
!item.containsKey(ATTR_ACCOUNT_E164) || // TODO: eventually require ATTR_CANONICALLY_DISCOVERABLE
// TODO: eventually require ATTR_CANONICALLY_DISCOVERABLE if (!item.containsKey(ATTR_ACCOUNT_DATA)
!item.containsKey(KEY_ACCOUNT_UUID)) { || !item.containsKey(ATTR_ACCOUNT_E164)
|| !item.containsKey(KEY_ACCOUNT_UUID)) {
throw new RuntimeException("item missing values"); throw new RuntimeException("item missing values");
} }
try { try {
Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class); final Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class);
final UUID accountIdentifier = UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer()); final UUID accountIdentifier = UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer());
final UUID phoneNumberIdentifierFromAttribute = AttributeValues.getUUID(item, ATTR_PNI_UUID, null); final UUID phoneNumberIdentifierFromAttribute = AttributeValues.getUUID(item, ATTR_PNI_UUID, null);
@ -883,12 +885,18 @@ public class Accounts extends AbstractDynamoDbStore {
account.setUuid(accountIdentifier); account.setUuid(accountIdentifier);
account.setUsername(AttributeValues.getString(item, ATTR_USERNAME, null)); account.setUsername(AttributeValues.getString(item, ATTR_USERNAME, null));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n())); account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE)).map(av -> av.bool()).orElse(false)); account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE))
.map(AttributeValue::bool)
.orElse(false));
return account; return account;
} catch (IOException e) { } catch (final IOException e) {
throw new RuntimeException("Could not read stored account data", e); throw new RuntimeException("Could not read stored account data", e);
} }
} }
private static boolean conditionalCheckFailed(final CancellationReason reason) {
return CONDITIONAL_CHECK_FAILED.equals(reason.code());
}
} }

View File

@ -6,5 +6,6 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
public interface PubSubAddress { public interface PubSubAddress {
public String serialize();
String serialize();
} }

View File

@ -0,0 +1,40 @@
package org.whispersystems.textsecuregcm.util;
import java.util.concurrent.CompletionException;
public final class ExceptionUtils {
private ExceptionUtils() {
// utility class
}
/**
* Extracts the cause of a {@link CompletionException}. If the given {@code throwable} is a
* {@code CompletionException}, this method will recursively iterate through its causal chain until it finds the first
* cause that is not a {@code CompletionException}. If the last {@code CompletionException} in the causal chain has a
* {@code null} cause, then this method returns the last {@code CompletionException} in the chain. If the given
* {@code throwable} is not a {@code CompletionException}, then this method returns the original {@code throwable}.
*
* @param throwable the throwable to "unwrap"
*
* @return the first entity in the given {@code throwable}'s causal chain that is not a {@code CompletionException}
*/
public static Throwable unwrap(Throwable throwable) {
while (throwable instanceof CompletionException e && throwable.getCause() != null) {
throwable = e.getCause();
}
return throwable;
}
/**
* Wraps the given {@code throwable} in a {@link CompletionException} unless the given {@code throwable} is alreadt
* a {@code CompletionException}, in which case this method returns the original throwable.
*
* @param throwable the throwable to wrap in a {@code CompletionException}
*/
public static CompletionException wrap(final Throwable throwable) {
return throwable instanceof CompletionException completionException
? completionException
: new CompletionException(throwable);
}
}

View File

@ -10,20 +10,25 @@ import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import javax.annotation.Nonnull;
public class SystemMapper { public class SystemMapper {
private static final ObjectMapper mapper = new ObjectMapper(); private static final ObjectMapper MAPPER = build();
static {
@Nonnull
public static ObjectMapper getMapper() {
return MAPPER;
}
@Nonnull
private static ObjectMapper build() {
final ObjectMapper mapper = new ObjectMapper();
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.registerModule(new JavaTimeModule()); mapper.registerModule(new JavaTimeModule());
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
}
public static ObjectMapper getMapper() {
return mapper; return mapper;
} }
} }

View File

@ -141,7 +141,7 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient, VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName()); configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, Accounts accounts = new Accounts(
dynamoDbClient, dynamoDbClient,
dynamoDbAsyncClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(), configuration.getDynamoDbTables().getAccounts().getTableName(),

View File

@ -143,7 +143,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient, VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName()); configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, Accounts accounts = new Accounts(
dynamoDbClient, dynamoDbClient,
dynamoDbAsyncClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(), configuration.getDynamoDbTables().getAccounts().getTableName(),

View File

@ -146,7 +146,7 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient, VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName()); configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, Accounts accounts = new Accounts(
dynamoDbClient, dynamoDbClient,
dynamoDbAsyncClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(), configuration.getDynamoDbTables().getAccounts().getTableName(),

View File

@ -157,7 +157,6 @@ class AccountsManagerChangeNumberIntegrationTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final Accounts accounts = new Accounts( final Accounts accounts = new Accounts(
dynamicConfigurationManager,
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(), ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(),
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(), ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(),
ACCOUNTS_DYNAMO_EXTENSION.getTableName(), ACCOUNTS_DYNAMO_EXTENSION.getTableName(),

View File

@ -127,7 +127,6 @@ class AccountsManagerConcurrentModificationIntegrationTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
accounts = new Accounts( accounts = new Accounts(
dynamicConfigurationManager,
dynamoDbExtension.getDynamoDbClient(), dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(), dynamoDbExtension.getDynamoDbAsyncClient(),
dynamoDbExtension.getTableName(), dynamoDbExtension.getTableName(),

View File

@ -5,6 +5,27 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
@ -22,17 +43,14 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.UsernameGenerator; import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.model.*; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import java.time.Clock; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.time.Duration; import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import java.time.Instant; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import java.util.*; import software.amazon.awssdk.services.dynamodb.model.KeyType;
import java.util.function.Consumer; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import static org.assertj.core.api.Assertions.assertThat; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
class AccountsManagerUsernameIntegrationTest { class AccountsManagerUsernameIntegrationTest {
@ -126,7 +144,6 @@ class AccountsManagerUsernameIntegrationTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
accounts = Mockito.spy(new Accounts( accounts = Mockito.spy(new Accounts(
dynamicConfigurationManager,
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(), ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(),
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(), ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(),
ACCOUNTS_DYNAMO_EXTENSION.getTableName(), ACCOUNTS_DYNAMO_EXTENSION.getTableName(),

View File

@ -17,7 +17,6 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.uuid.UUIDComparator; import com.fasterxml.uuid.UUIDComparator;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
@ -138,7 +137,6 @@ class AccountsTest {
this.accounts = new Accounts( this.accounts = new Accounts(
clock, clock,
mockDynamicConfigManager,
dynamoDbExtension.getDynamoDbClient(), dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(), dynamoDbExtension.getDynamoDbAsyncClient(),
dynamoDbExtension.getTableName(), dynamoDbExtension.getTableName(),
@ -373,7 +371,7 @@ class AccountsTest {
void testUpdateWithMockTransactionConflictException(boolean wrapException) { void testUpdateWithMockTransactionConflictException(boolean wrapException) {
final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class); final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class);
accounts = new Accounts(mockDynamicConfigManager, mock(DynamoDbClient.class), accounts = new Accounts(mock(DynamoDbClient.class),
dynamoDbAsyncClient, dynamoDbExtension.getTableName(), dynamoDbAsyncClient, dynamoDbExtension.getTableName(),
NUMBER_CONSTRAINT_TABLE_NAME, PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE); NUMBER_CONSTRAINT_TABLE_NAME, PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE);