diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 9bc9cdaa3..c911a705c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -337,7 +337,7 @@ public class WhisperServerService extends Application> items) { - AtomicReference outcome = new AtomicReference<>(); - batchWriteItemsFirstPass.record( - () -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build()))); + final AtomicReference outcome = new AtomicReference<>(); + writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome); int attemptCount = 0; while (!outcome.get().unprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) { - batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder() - .requestItems(outcome.get().unprocessedItems()) - .build()))); + writeAndStoreOutcome(outcome.get().unprocessedItems(), batchWriteItemsRetryPass, outcome); ++attemptCount; } if (!outcome.get().unprocessedItems().isEmpty()) { - int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum(); + final int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum(); logger.error( "Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, totalItems); @@ -68,19 +70,28 @@ public abstract class AbstractDynamoDbStore { } } - protected List> scan(ScanRequest scanRequest, int max) { - + @Nonnull + protected List> scan(final ScanRequest scanRequest, final int max) { return db().scanPaginator(scanRequest) .items() .stream() .limit(max) - .collect(Collectors.toList()); + .toList(); + } + + private void writeAndStoreOutcome( + final Map> items, + final Timer timer, + final AtomicReference outcome) { + timer.record( + () -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build())) + ); } static void writeInBatches(final Iterable items, final Consumer> action) { final List batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE); - for (T item : items) { + for (final T item : items) { batch.add(item); if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 3e79bd41f..d4f3d8fa2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; import static com.codahale.metrics.MetricRegistry.name; +import static java.util.Objects.requireNonNull; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.annotations.VisibleForTesting; @@ -18,7 +19,6 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -29,12 +29,14 @@ import java.util.UUID; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; +import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UUIDUtil; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; @@ -56,8 +58,31 @@ import software.amazon.awssdk.services.dynamodb.model.Update; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.utils.CompletableFutureUtils; +@SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class Accounts extends AbstractDynamoDbStore { + private static final Logger log = LoggerFactory.getLogger(Accounts.class); + + private static final byte RESERVED_USERNAME_HASH_VERSION = 1; + + private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create")); + private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber")); + private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername")); + private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername")); + private static final Timer CLEAR_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "clearUsername")); + private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update")); + private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber")); + private static final Timer GET_BY_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "getByUsername")); + private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni")); + private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid")); + private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom")); + private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset")); + private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete")); + + private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed"; + + private static final String TRANSACTION_CONFLICT = "TransactionConflict"; + // uuid, primary key static final String KEY_ACCOUNT_UUID = "U"; // uuid, attribute on account table, primary key for PNI table @@ -78,45 +103,32 @@ public class Accounts extends AbstractDynamoDbStore { static final String ATTR_TTL = "TTL"; private final Clock clock; - private final DynamoDbClient client; + private final DynamoDbAsyncClient asyncClient; private final String phoneNumberConstraintTableName; + private final String phoneNumberIdentifierConstraintTableName; + private final String usernamesConstraintTableName; + private final String accountsTableName; private final int scanPageSize; - private static final byte RESERVED_USERNAME_HASH_VERSION = 1; - - private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create")); - private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber")); - private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername")); - private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername")); - private static final Timer CLEAR_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "clearUsername")); - private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update")); - private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber")); - private static final Timer GET_BY_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "getByUsername")); - private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni")); - private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid")); - private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom")); - private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset")); - private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete")); - - private static final Logger log = LoggerFactory.getLogger(Accounts.class); @VisibleForTesting public Accounts( final Clock clock, - final DynamicConfigurationManager dynamicConfigurationManager, - DynamoDbClient client, DynamoDbAsyncClient asyncClient, - String accountsTableName, String phoneNumberConstraintTableName, - String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, + final DynamoDbClient client, + final DynamoDbAsyncClient asyncClient, + final String accountsTableName, + final String phoneNumberConstraintTableName, + final String phoneNumberIdentifierConstraintTableName, + final String usernamesConstraintTableName, final int scanPageSize) { super(client); this.clock = clock; - this.client = client; this.asyncClient = asyncClient; this.phoneNumberConstraintTableName = phoneNumberConstraintTableName; this.phoneNumberIdentifierConstraintTableName = phoneNumberIdentifierConstraintTableName; @@ -125,105 +137,61 @@ public class Accounts extends AbstractDynamoDbStore { this.scanPageSize = scanPageSize; } - public Accounts(final DynamicConfigurationManager dynamicConfigurationManager, - DynamoDbClient client, DynamoDbAsyncClient asyncClient, - String accountsTableName, String phoneNumberConstraintTableName, - String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, + public Accounts( + final DynamoDbClient client, + final DynamoDbAsyncClient asyncClient, + final String accountsTableName, + final String phoneNumberConstraintTableName, + final String phoneNumberIdentifierConstraintTableName, + final String usernamesConstraintTableName, final int scanPageSize) { - this(Clock.systemUTC(), dynamicConfigurationManager, client, asyncClient, accountsTableName, + this(Clock.systemUTC(), client, asyncClient, accountsTableName, phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName, scanPageSize); } - public boolean create(Account account) { + public boolean create(final Account account) { return CREATE_TIMER.record(() -> { - try { - TransactWriteItem phoneNumberConstraintPut = TransactWriteItem.builder() - .put( - Put.builder() - .tableName(phoneNumberConstraintTableName) - .item(Map.of( - ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()), - KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) - .conditionExpression( - "attribute_not_exists(#number) OR (attribute_exists(#number) AND #uuid = :uuid)") - .expressionAttributeNames( - Map.of("#uuid", KEY_ACCOUNT_UUID, - "#number", ATTR_ACCOUNT_E164)) - .expressionAttributeValues( - Map.of(":uuid", AttributeValues.fromUUID(account.getUuid()))) - .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) - .build()) - .build(); + final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid()); + final AttributeValue numberAttr = AttributeValues.fromString(account.getNumber()); + final AttributeValue pniUuidAttr = AttributeValues.fromUUID(account.getPhoneNumberIdentifier()); - TransactWriteItem phoneNumberIdentifierConstraintPut = TransactWriteItem.builder() - .put( - Put.builder() - .tableName(phoneNumberIdentifierConstraintTableName) - .item(Map.of( - ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier()), - KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) - .conditionExpression( - "attribute_not_exists(#pni) OR (attribute_exists(#pni) AND #uuid = :uuid)") - .expressionAttributeNames( - Map.of("#uuid", KEY_ACCOUNT_UUID, - "#pni", ATTR_PNI_UUID)) - .expressionAttributeValues( - Map.of(":uuid", AttributeValues.fromUUID(account.getUuid()))) - .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) - .build()) - .build(); + final TransactWriteItem phoneNumberConstraintPut = buildConstraintTablePutIfAbsent( + phoneNumberConstraintTableName, uuidAttr, ATTR_ACCOUNT_E164, numberAttr); - final Map item = new HashMap<>(Map.of( - KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()), - ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()), - ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier()), - ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), - ATTR_VERSION, AttributeValues.fromInt(account.getVersion()), - ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory()))); + final TransactWriteItem phoneNumberIdentifierConstraintPut = buildConstraintTablePutIfAbsent( + phoneNumberIdentifierConstraintTableName, uuidAttr, ATTR_PNI_UUID, pniUuidAttr); - // Add the UAK if it's in the account - account.getUnidentifiedAccessKey() - .map(AttributeValues::fromByteArray) - .ifPresent(uak -> item.put(ATTR_UAK, uak)); - - TransactWriteItem accountPut = TransactWriteItem.builder() - .put(Put.builder() - .conditionExpression("attribute_not_exists(#number) OR #number = :number") - .expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164)) - .expressionAttributeValues(Map.of(":number", AttributeValues.fromString(account.getNumber()))) - .tableName(accountsTableName) - .item(item) - .build()) - .build(); + final TransactWriteItem accountPut = buildAccountPut(account, uuidAttr, numberAttr, pniUuidAttr); final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() .transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut) .build(); try { - client.transactWriteItems(request); - } catch (TransactionCanceledException e) { + db().transactWriteItems(request); + } catch (final TransactionCanceledException e) { final CancellationReason accountCancellationReason = e.cancellationReasons().get(2); - if ("ConditionalCheckFailed".equals(accountCancellationReason.code())) { + if (conditionalCheckFailed(accountCancellationReason)) { throw new IllegalArgumentException("account identifier present with different phone number"); } final CancellationReason phoneNumberConstraintCancellationReason = e.cancellationReasons().get(0); final CancellationReason phoneNumberIdentifierConstraintCancellationReason = e.cancellationReasons().get(1); - if ("ConditionalCheckFailed".equals(phoneNumberConstraintCancellationReason.code()) || - "ConditionalCheckFailed".equals(phoneNumberIdentifierConstraintCancellationReason.code())) { + if (conditionalCheckFailed(phoneNumberConstraintCancellationReason) + || conditionalCheckFailed(phoneNumberIdentifierConstraintCancellationReason)) { // In theory, both reasons should trip in tandem and either should give us the information we need. Even so, // we'll be cautious here and make sure we're choosing a condition check that really failed. - final CancellationReason reason = "ConditionalCheckFailed".equals(phoneNumberConstraintCancellationReason.code()) ? - phoneNumberConstraintCancellationReason : phoneNumberIdentifierConstraintCancellationReason; + final CancellationReason reason = conditionalCheckFailed(phoneNumberConstraintCancellationReason) + ? phoneNumberConstraintCancellationReason + : phoneNumberIdentifierConstraintCancellationReason; - ByteBuffer actualAccountUuid = reason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer(); + final ByteBuffer actualAccountUuid = reason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer(); account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid)); final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow(); @@ -235,7 +203,7 @@ public class Accounts extends AbstractDynamoDbStore { return false; } - if ("TransactionConflict".equals(accountCancellationReason.code())) { + if (TRANSACTION_CONFLICT.equals(accountCancellationReason.code())) { // this should only happen if two clients manage to make concurrent create() calls throw new ContestedOptimisticLockException(); } @@ -243,7 +211,7 @@ public class Accounts extends AbstractDynamoDbStore { // this shouldn't happen throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e)); } - } catch (JsonProcessingException e) { + } catch (final JsonProcessingException e) { throw new IllegalArgumentException(e); } @@ -275,62 +243,34 @@ public class Accounts extends AbstractDynamoDbStore { try { final List writeItems = new ArrayList<>(); + final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid()); + final AttributeValue numberAttr = AttributeValues.fromString(number); + final AttributeValue pniAttr = AttributeValues.fromUUID(phoneNumberIdentifier); - writeItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(phoneNumberConstraintTableName) - .key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(originalNumber))) - .build()) - .build()); - - writeItems.add(TransactWriteItem.builder() - .put(Put.builder() - .tableName(phoneNumberConstraintTableName) - .item(Map.of( - KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()), - ATTR_ACCOUNT_E164, AttributeValues.fromString(number))) - .conditionExpression("attribute_not_exists(#number)") - .expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164)) - .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) - .build()) - .build()); - - writeItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(phoneNumberIdentifierConstraintTableName) - .key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(originalPni))) - .build()) - .build()); - - writeItems.add(TransactWriteItem.builder() - .put(Put.builder() - .tableName(phoneNumberIdentifierConstraintTableName) - .item(Map.of( - KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()), - ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier))) - .conditionExpression("attribute_not_exists(#pni)") - .expressionAttributeNames(Map.of("#pni", ATTR_PNI_UUID)) - .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) - .build()) - .build()); - + writeItems.add(buildDelete(phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, originalNumber)); + writeItems.add(buildConstraintTablePut(phoneNumberConstraintTableName, uuidAttr, ATTR_ACCOUNT_E164, numberAttr)); + writeItems.add(buildDelete(phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, originalPni)); + writeItems.add(buildConstraintTablePut(phoneNumberIdentifierConstraintTableName, uuidAttr, ATTR_PNI_UUID, pniAttr)); writeItems.add( TransactWriteItem.builder() .update(Update.builder() .tableName(accountsTableName) - .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) - .updateExpression("SET #data = :data, #number = :number, #pni = :pni, #cds = :cds ADD #version :version_increment") - .conditionExpression("attribute_exists(#number) AND #version = :version") - .expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164, + .key(Map.of(KEY_ACCOUNT_UUID, uuidAttr)) + .updateExpression( + "SET #data = :data, #number = :number, #pni = :pni, #cds = :cds ADD #version :version_increment") + .conditionExpression( + "attribute_exists(#number) AND #version = :version") + .expressionAttributeNames(Map.of( + "#number", ATTR_ACCOUNT_E164, "#data", ATTR_ACCOUNT_DATA, "#cds", ATTR_CANONICALLY_DISCOVERABLE, "#pni", ATTR_PNI_UUID, "#version", ATTR_VERSION)) .expressionAttributeValues(Map.of( + ":number", numberAttr, ":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), - ":number", AttributeValues.fromString(number), - ":pni", AttributeValues.fromUUID(phoneNumberIdentifier), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), + ":pni", pniAttr, ":version", AttributeValues.fromInt(account.getVersion()), ":version_increment", AttributeValues.fromInt(1))) .build()) @@ -340,7 +280,7 @@ public class Accounts extends AbstractDynamoDbStore { .transactItems(writeItems) .build(); - client.transactWriteItems(request); + db().transactWriteItems(request); account.setVersion(account.getVersion() + 1); succeeded = true; @@ -354,21 +294,6 @@ public class Accounts extends AbstractDynamoDbStore { }); } - public static byte[] reservedUsernameHash(final UUID accountId, final String reservedUsername) { - final MessageDigest sha256; - try { - sha256 = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } - final ByteBuffer byteBuffer = ByteBuffer.allocate(32 + 1); - sha256.update(reservedUsername.getBytes(StandardCharsets.UTF_8)); - sha256.update(UUIDUtil.toBytes(accountId)); - byteBuffer.put(RESERVED_USERNAME_HASH_VERSION); - byteBuffer.put(sha256.digest()); - return byteBuffer.array(); - } - /** * Reserve a username under a token * @@ -386,7 +311,7 @@ public class Accounts extends AbstractDynamoDbStore { boolean succeeded = false; - long expirationTime = clock.instant().plus(ttl).getEpochSecond(); + final long expirationTime = clock.instant().plus(ttl).getEpochSecond(); final UUID reservationToken = UUID.randomUUID(); try { @@ -425,14 +350,14 @@ public class Accounts extends AbstractDynamoDbStore { .transactItems(writeItems) .build(); - client.transactWriteItems(request); + db().transactWriteItems(request); account.setVersion(account.getVersion() + 1); succeeded = true; } catch (final JsonProcessingException e) { throw new IllegalArgumentException(e); } catch (final TransactionCanceledException e) { - if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch("ConditionalCheckFailed"::equals)) { + if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) { throw new ContestedOptimisticLockException(); } throw e; @@ -520,25 +445,21 @@ public class Accounts extends AbstractDynamoDbStore { .build()) .build()); - maybeOriginalUsername.ifPresent(originalUsername -> writeItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(usernamesConstraintTableName) - .key(Map.of(ATTR_USERNAME, AttributeValues.fromString(originalUsername))) - .build()) - .build())); + maybeOriginalUsername.ifPresent(originalUsername -> writeItems.add( + buildDelete(usernamesConstraintTableName, ATTR_USERNAME, originalUsername))); final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() .transactItems(writeItems) .build(); - client.transactWriteItems(request); + db().transactWriteItems(request); account.setVersion(account.getVersion() + 1); succeeded = true; } catch (final JsonProcessingException e) { throw new IllegalArgumentException(e); } catch (final TransactionCanceledException e) { - if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch("ConditionalCheckFailed"::equals)) { + if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) { throw new ContestedOptimisticLockException(); } throw e; @@ -551,7 +472,7 @@ public class Accounts extends AbstractDynamoDbStore { } } - public void clearUsername(Account account) { + public void clearUsername(final Account account) { account.getUsername().ifPresent(username -> { CLEAR_USERNAME_TIMER.record(() -> { account.setUsername(null); @@ -578,25 +499,20 @@ public class Accounts extends AbstractDynamoDbStore { .build()) .build()); - writeItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(usernamesConstraintTableName) - .key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username))) - .build()) - .build()); + writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME, username)); final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() .transactItems(writeItems) .build(); - client.transactWriteItems(request); + db().transactWriteItems(request); account.setVersion(account.getVersion() + 1); succeeded = true; } catch (final JsonProcessingException e) { throw new IllegalArgumentException(e); } catch (final TransactionCanceledException e) { - if ("ConditionalCheckFailed".equals(e.cancellationReasons().get(0).code())) { + if (conditionalCheckFailed(e.cancellationReasons().get(0))) { throw new ContestedOptimisticLockException(); } @@ -610,27 +526,18 @@ public class Accounts extends AbstractDynamoDbStore { }); } - /** - * Extract the cause from a CompletionException - */ - private static Throwable unwrap(Throwable throwable) { - while (throwable instanceof CompletionException e && throwable.getCause() != null) { - throwable = e.getCause(); - } - return throwable; - } - - public CompletionStage updateAsync(Account account) { + @Nonnull + public CompletionStage updateAsync(final Account account) { return record(UPDATE_TIMER, () -> { final UpdateItemRequest updateItemRequest; try { // username, e164, and pni cannot be modified through this method - Map attrNames = new HashMap<>(Map.of( + final Map attrNames = new HashMap<>(Map.of( "#number", ATTR_ACCOUNT_E164, "#data", ATTR_ACCOUNT_DATA, "#cds", ATTR_CANONICALLY_DISCOVERABLE, "#version", ATTR_VERSION)); - Map attrValues = new HashMap<>(Map.of( + final Map attrValues = new HashMap<>(Map.of( ":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":version", AttributeValues.fromInt(account.getVersion()), @@ -654,7 +561,7 @@ public class Accounts extends AbstractDynamoDbStore { .expressionAttributeNames(attrNames) .expressionAttributeValues(attrValues) .build(); - } catch (JsonProcessingException e) { + } catch (final JsonProcessingException e) { throw new IllegalArgumentException(e); } @@ -664,7 +571,7 @@ public class Accounts extends AbstractDynamoDbStore { return (Void) null; }) .exceptionally(throwable -> { - final Throwable unwrapped = unwrap(throwable); + final Throwable unwrapped = ExceptionUtils.unwrap(throwable); if (unwrapped instanceof TransactionConflictException) { throw new ContestedOptimisticLockException(); } else if (unwrapped instanceof ConditionalCheckFailedException e) { @@ -679,12 +586,12 @@ public class Accounts extends AbstractDynamoDbStore { }); } - public void update(Account account) throws ContestedOptimisticLockException { + public void update(final Account account) throws ContestedOptimisticLockException { try { - this.updateAsync(account).toCompletableFuture().join(); - } catch (CompletionException e) { + updateAsync(account).toCompletableFuture().join(); + } catch (final CompletionException e) { // unwrap CompletionExceptions, throw as long is it's unchecked - Throwables.throwIfUnchecked(unwrap(e)); + Throwables.throwIfUnchecked(ExceptionUtils.unwrap(e)); // if we otherwise somehow got a wrapped checked exception, // rethrow the checked exception wrapped by the original CompletionException @@ -698,15 +605,14 @@ public class Accounts extends AbstractDynamoDbStore { } public boolean usernameAvailable(final Optional reservationToken, final String username) { - final GetItemResponse response = client.getItem(GetItemRequest.builder() - .tableName(usernamesConstraintTableName) - .key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username))) - .build()); - if (!response.hasItem()) { + final Optional> usernameItem = itemByKey( + usernamesConstraintTableName, ATTR_USERNAME, AttributeValues.fromString(username)); + + if (usernameItem.isEmpty()) { // username is free return true; } - final Map item = response.item(); + final Map item = usernameItem.get(); if (AttributeValues.getLong(item, ATTR_TTL, Long.MAX_VALUE) < clock.instant().getEpochSecond()) { // username was reserved, but has expired @@ -719,112 +625,56 @@ public class Accounts extends AbstractDynamoDbStore { .orElse(false); } - public Optional getByE164(String number) { - return GET_BY_NUMBER_TIMER.record(() -> { - - final GetItemResponse response = client.getItem(GetItemRequest.builder() - .tableName(phoneNumberConstraintTableName) - .key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(number))) - .build()); - - return Optional.ofNullable(response.item()) - .map(item -> item.get(KEY_ACCOUNT_UUID)) - .map(this::accountByUuid) - .map(Accounts::fromItem); - }); - } - - public Optional getByUsername(final String username) { - return GET_BY_USERNAME_TIMER.record(() -> { - - final GetItemResponse response = client.getItem(GetItemRequest.builder() - .tableName(usernamesConstraintTableName) - .key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username))) - .build()); - - - return Optional.ofNullable(response.item()) - // ignore items with a ttl (reservations) - .filter(item -> !item.containsKey(ATTR_TTL)) - .map(item -> item.get(KEY_ACCOUNT_UUID)) - .map(this::accountByUuid) - .map(Accounts::fromItem); - }); + @Nonnull + public Optional getByE164(final String number) { + return getByIndirectLookup( + GET_BY_NUMBER_TIMER, phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, AttributeValues.fromString(number)); } + @Nonnull public Optional getByPhoneNumberIdentifier(final UUID phoneNumberIdentifier) { - return GET_BY_PNI_TIMER.record(() -> { - - final GetItemResponse response = client.getItem(GetItemRequest.builder() - .tableName(phoneNumberIdentifierConstraintTableName) - .key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier))) - .build()); - - return Optional.ofNullable(response.item()) - .map(item -> item.get(KEY_ACCOUNT_UUID)) - .map(this::accountByUuid) - .map(Accounts::fromItem); - }); + return getByIndirectLookup( + GET_BY_PNI_TIMER, phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, AttributeValues.fromUUID(phoneNumberIdentifier)); } - private Map accountByUuid(AttributeValue uuid) { - GetItemResponse r = client.getItem(GetItemRequest.builder() - .tableName(accountsTableName) - .key(Map.of(KEY_ACCOUNT_UUID, uuid)) - .consistentRead(true) - .build()); - return r.item().isEmpty() ? null : r.item(); + @Nonnull + public Optional getByUsername(final String username) { + return getByIndirectLookup( + GET_BY_USERNAME_TIMER, + usernamesConstraintTableName, + ATTR_USERNAME, + AttributeValues.fromString(username), + item -> !item.containsKey(ATTR_TTL) // ignore items with a ttl (reservations) + ); } - public Optional getByAccountIdentifier(UUID uuid) { - return GET_BY_UUID_TIMER.record(() -> - Optional.ofNullable(accountByUuid(AttributeValues.fromUUID(uuid))) - .map(Accounts::fromItem)); + @Nonnull + public Optional getByAccountIdentifier(final UUID uuid) { + return requireNonNull(GET_BY_UUID_TIMER.record(() -> + itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)) + .map(Accounts::fromItem))); } - public void delete(UUID uuid) { - DELETE_TIMER.record(() -> { + public void delete(final UUID uuid) { + DELETE_TIMER.record(() -> getByAccountIdentifier(uuid).ifPresent(account -> { - getByAccountIdentifier(uuid).ifPresent(account -> { + final List transactWriteItems = new ArrayList<>(List.of( + buildDelete(phoneNumberConstraintTableName, ATTR_ACCOUNT_E164, account.getNumber()), + buildDelete(accountsTableName, KEY_ACCOUNT_UUID, uuid), + buildDelete(phoneNumberIdentifierConstraintTableName, ATTR_PNI_UUID, account.getPhoneNumberIdentifier()) + )); - TransactWriteItem phoneNumberDelete = TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(phoneNumberConstraintTableName) - .key(Map.of(ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()))) - .build()) - .build(); + account.getUsername().ifPresent(username -> transactWriteItems.add( + buildDelete(usernamesConstraintTableName, ATTR_USERNAME, username))); - TransactWriteItem accountDelete = TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(accountsTableName) - .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid))) - .build()) - .build(); + final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() + .transactItems(transactWriteItems).build(); - final List transactWriteItems = new ArrayList<>(List.of(phoneNumberDelete, accountDelete)); - - transactWriteItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(phoneNumberIdentifierConstraintTableName) - .key(Map.of(ATTR_PNI_UUID, AttributeValues.fromUUID(account.getPhoneNumberIdentifier()))) - .build()) - .build()); - - account.getUsername().ifPresent(username -> transactWriteItems.add(TransactWriteItem.builder() - .delete(Delete.builder() - .tableName(usernamesConstraintTableName) - .key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username))) - .build()) - .build())); - - TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() - .transactItems(transactWriteItems).build(); - - client.transactWriteItems(request); - }); - }); + db().transactWriteItems(request); + })); } + @Nonnull public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) { final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() .limit(scanPageSize) @@ -833,6 +683,7 @@ public class Accounts extends AbstractDynamoDbStore { return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_OFFSET_TIMER); } + @Nonnull public AccountCrawlChunk getAllFromStart(final int maxCount) { final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() .limit(scanPageSize); @@ -840,34 +691,185 @@ public class Accounts extends AbstractDynamoDbStore { return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER); } - private static CompletionStage record(final Timer timer, Supplier> toRecord) { - final Instant start = Instant.now(); - return toRecord.get().whenComplete((ignoreT, ignoreE) -> timer.record(Duration.between(start, Instant.now()))); + @Nonnull + private Optional getByIndirectLookup( + final Timer timer, + final String tableName, + final String keyName, + final AttributeValue keyValue) { + return getByIndirectLookup(timer, tableName, keyName, keyValue, i -> true); } + @Nonnull + private Optional getByIndirectLookup( + final Timer timer, + final String tableName, + final String keyName, + final AttributeValue keyValue, + final Predicate> predicate) { + + return requireNonNull(timer.record(() -> itemByKey(tableName, keyName, keyValue) + .filter(predicate) + .map(item -> item.get(KEY_ACCOUNT_UUID)) + .flatMap(uuid -> itemByKey(accountsTableName, KEY_ACCOUNT_UUID, uuid)) + .map(Accounts::fromItem))); + } + + @Nonnull + private Optional> itemByKey(final String table, final String keyName, final AttributeValue keyValue) { + final GetItemResponse response = db().getItem(GetItemRequest.builder() + .tableName(table) + .key(Map.of(keyName, keyValue)) + .consistentRead(true) + .build()); + return Optional.ofNullable(response.item()).filter(m -> !m.isEmpty()); + } + + @Nonnull + private TransactWriteItem buildAccountPut( + final Account account, + final AttributeValue uuidAttr, + final AttributeValue numberAttr, + final AttributeValue pniUuidAttr) throws JsonProcessingException { + + final Map item = new HashMap<>(Map.of( + KEY_ACCOUNT_UUID, uuidAttr, + ATTR_ACCOUNT_E164, numberAttr, + ATTR_PNI_UUID, pniUuidAttr, + ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), + ATTR_VERSION, AttributeValues.fromInt(account.getVersion()), + ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory()))); + + // Add the UAK if it's in the account + account.getUnidentifiedAccessKey() + .map(AttributeValues::fromByteArray) + .ifPresent(uak -> item.put(ATTR_UAK, uak)); + + return TransactWriteItem.builder() + .put(Put.builder() + .conditionExpression("attribute_not_exists(#number) OR #number = :number") + .expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164)) + .expressionAttributeValues(Map.of(":number", numberAttr)) + .tableName(accountsTableName) + .item(item) + .build()) + .build(); + } + + @Nonnull + private static TransactWriteItem buildConstraintTablePutIfAbsent( + final String tableName, + final AttributeValue uuidAttr, + final String keyName, + final AttributeValue keyValue + ) { + return TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(Map.of( + keyName, keyValue, + KEY_ACCOUNT_UUID, uuidAttr)) + .conditionExpression( + "attribute_not_exists(#key) OR #uuid = :uuid") + .expressionAttributeNames(Map.of( + "#key", keyName, + "#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of( + ":uuid", uuidAttr)) + .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) + .build()) + .build(); + } + + @Nonnull + private static TransactWriteItem buildConstraintTablePut( + final String tableName, + final AttributeValue uuidAttr, + final String keyName, + final AttributeValue keyValue) { + return TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(Map.of( + keyName, keyValue, + KEY_ACCOUNT_UUID, uuidAttr)) + .conditionExpression( + "attribute_not_exists(#key)") + .expressionAttributeNames(Map.of( + "#key", keyName)) + .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) + .build()) + .build(); + } + + @Nonnull + private static TransactWriteItem buildDelete(final String tableName, final String keyName, final String keyValue) { + return buildDelete(tableName, keyName, AttributeValues.fromString(keyValue)); + } + + @Nonnull + private static TransactWriteItem buildDelete(final String tableName, final String keyName, final UUID keyValue) { + return buildDelete(tableName, keyName, AttributeValues.fromUUID(keyValue)); + } + + @Nonnull + private static TransactWriteItem buildDelete(final String tableName, final String keyName, final AttributeValue keyValue) { + return TransactWriteItem.builder() + .delete(Delete.builder() + .tableName(tableName) + .key(Map.of(keyName, keyValue)) + .build()) + .build(); + } + + @Nonnull + private static CompletionStage record(final Timer timer, final Supplier> toRecord) { + final Timer.Sample sample = Timer.start(); + return toRecord.get().whenComplete((ignoreT, ignoreE) -> sample.stop(timer)); + } + + @Nonnull private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) { scanRequestBuilder.tableName(accountsTableName); - final List> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)); + final List> items = requireNonNull(timer.record(() -> scan(scanRequestBuilder.build(), maxCount))); final List accounts = items.stream().map(Accounts::fromItem).toList(); return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null); } + @Nonnull private static String extractCancellationReasonCodes(final TransactionCanceledException exception) { return exception.cancellationReasons().stream() .map(CancellationReason::code) .collect(Collectors.joining(", ")); } + @Nonnull + public static byte[] reservedUsernameHash(final UUID accountId, final String reservedUsername) { + final MessageDigest sha256; + try { + sha256 = MessageDigest.getInstance("SHA-256"); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError(e); + } + final ByteBuffer byteBuffer = ByteBuffer.allocate(32 + 1); + sha256.update(reservedUsername.getBytes(StandardCharsets.UTF_8)); + sha256.update(UUIDUtil.toBytes(accountId)); + byteBuffer.put(RESERVED_USERNAME_HASH_VERSION); + byteBuffer.put(sha256.digest()); + return byteBuffer.array(); + } + @VisibleForTesting - static Account fromItem(Map item) { - if (!item.containsKey(ATTR_ACCOUNT_DATA) || - !item.containsKey(ATTR_ACCOUNT_E164) || - // TODO: eventually require ATTR_CANONICALLY_DISCOVERABLE - !item.containsKey(KEY_ACCOUNT_UUID)) { + @Nonnull + static Account fromItem(final Map item) { + // TODO: eventually require ATTR_CANONICALLY_DISCOVERABLE + if (!item.containsKey(ATTR_ACCOUNT_DATA) + || !item.containsKey(ATTR_ACCOUNT_E164) + || !item.containsKey(KEY_ACCOUNT_UUID)) { throw new RuntimeException("item missing values"); } try { - Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class); + final Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class); final UUID accountIdentifier = UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer()); final UUID phoneNumberIdentifierFromAttribute = AttributeValues.getUUID(item, ATTR_PNI_UUID, null); @@ -883,12 +885,18 @@ public class Accounts extends AbstractDynamoDbStore { account.setUuid(accountIdentifier); account.setUsername(AttributeValues.getString(item, ATTR_USERNAME, null)); account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n())); - account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE)).map(av -> av.bool()).orElse(false)); + account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE)) + .map(AttributeValue::bool) + .orElse(false)); return account; - } catch (IOException e) { + } catch (final IOException e) { throw new RuntimeException("Could not read stored account data", e); } } + + private static boolean conditionalCheckFailed(final CancellationReason reason) { + return CONDITIONAL_CHECK_FAILED.equals(reason.code()); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java index 91af572f3..f63ea86ff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java @@ -6,5 +6,6 @@ package org.whispersystems.textsecuregcm.storage; public interface PubSubAddress { - public String serialize(); + + String serialize(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java new file mode 100644 index 000000000..67d74def2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java @@ -0,0 +1,40 @@ +package org.whispersystems.textsecuregcm.util; + +import java.util.concurrent.CompletionException; + +public final class ExceptionUtils { + + private ExceptionUtils() { + // utility class + } + + /** + * Extracts the cause of a {@link CompletionException}. If the given {@code throwable} is a + * {@code CompletionException}, this method will recursively iterate through its causal chain until it finds the first + * cause that is not a {@code CompletionException}. If the last {@code CompletionException} in the causal chain has a + * {@code null} cause, then this method returns the last {@code CompletionException} in the chain. If the given + * {@code throwable} is not a {@code CompletionException}, then this method returns the original {@code throwable}. + * + * @param throwable the throwable to "unwrap" + * + * @return the first entity in the given {@code throwable}'s causal chain that is not a {@code CompletionException} + */ + public static Throwable unwrap(Throwable throwable) { + while (throwable instanceof CompletionException e && throwable.getCause() != null) { + throwable = e.getCause(); + } + return throwable; + } + + /** + * Wraps the given {@code throwable} in a {@link CompletionException} unless the given {@code throwable} is alreadt + * a {@code CompletionException}, in which case this method returns the original throwable. + * + * @param throwable the throwable to wrap in a {@code CompletionException} + */ + public static CompletionException wrap(final Throwable throwable) { + return throwable instanceof CompletionException completionException + ? completionException + : new CompletionException(throwable); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/SystemMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/SystemMapper.java index 340617c47..7c3e6b667 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/SystemMapper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/SystemMapper.java @@ -10,20 +10,25 @@ import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import javax.annotation.Nonnull; public class SystemMapper { - private static final ObjectMapper mapper = new ObjectMapper(); + private static final ObjectMapper MAPPER = build(); - static { + + @Nonnull + public static ObjectMapper getMapper() { + return MAPPER; + } + + @Nonnull + private static ObjectMapper build() { + final ObjectMapper mapper = new ObjectMapper(); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.registerModule(new JavaTimeModule()); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - } - - public static ObjectMapper getMapper() { return mapper; } - } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java index 8bf20d7ad..eeadff1a9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java @@ -141,7 +141,7 @@ public class AssignUsernameCommand extends EnvironmentCommand