Backfill the top-level UAK attribute while crawling accounts.

Accounts.scanForChunk now passes the scanned items through normalizeIfRequired, which
rebuilds each Account and, for items that lack the top-level UAK attribute although the
account blob carries an unidentified-access key, writes the key back with PartiQL UPDATE
statements sent in batches of 25. Per-batch outcomes are recorded in a timer and two
counters; failures are logged and never abort the crawl. AccountsTest gains three tests
covering the no-UAK, single-account and multi-batch cases.

NOTE(review): line breaks, blank context lines and generic type arguments below were
reconstructed from a whitespace-collapsed copy and checked against the hunk line counts;
leading indentation of context lines is assumed to be the project's 2-space style.

diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
index 2cf36cfb2..bf37c3548 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
@@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import io.micrometer.core.instrument.Counter;
 import io.micrometer.core.instrument.Metrics;
 import io.micrometer.core.instrument.Timer;
 import java.io.IOException;
@@ -28,6 +30,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
 import org.whispersystems.textsecuregcm.util.UUIDUtil;
 import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
 import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
+import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementRequest;
+import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementResponse;
+import software.amazon.awssdk.services.dynamodb.model.BatchStatementError;
+import software.amazon.awssdk.services.dynamodb.model.BatchStatementRequest;
+import software.amazon.awssdk.services.dynamodb.model.BatchStatementResponse;
 import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
 import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
 import software.amazon.awssdk.services.dynamodb.model.Delete;
@@ -84,6 +91,10 @@ public class Accounts extends AbstractDynamoDbStore {
   private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
   private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
   private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
+  private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem"));
+
+  private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess"));
+  private static final Counter UAK_NORMALIZE_ERROR_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakError"));
 
   private static final Logger log = LoggerFactory.getLogger(Accounts.class);
 
@@ -627,15 +638,67 @@ public class Accounts extends AbstractDynamoDbStore {
     return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
   }
 
+  private List<Account> normalizeIfRequired(final List<Map<String, AttributeValue>> items) {
+
+    // The UAK top-level attribute may not exist on older records,
+    // if it is absent and there is a UAK in the account blob we'll
+    // add the UAK as a top-level attribute
+    // TODO: Can eliminate this once all uaks exist as top-level attributes
+    final List<Account> allAccounts = new ArrayList<>();
+    final List<Account> accountsToNormalize = new ArrayList<>();
+    for (Map<String, AttributeValue> item : items) {
+      final Account account = fromItem(item);
+      allAccounts.add(account);
+
+      if (!item.containsKey(ATTR_UAK) && account.getUnidentifiedAccessKey().isPresent()) {
+        // the top level uak attribute doesn't exist, but there's a uak in the account
+        accountsToNormalize.add(account);
+      }
+    }
+
+    final int BATCH_SIZE = 25; // dynamodb max batch size
+    final String updateUakStatement = String.format("UPDATE %s SET %s = ? WHERE %s = ?", accountsTableName, ATTR_UAK, KEY_ACCOUNT_UUID);
+    for (List<Account> toNormalize : Lists.partition(accountsToNormalize, BATCH_SIZE)) {
+      NORMALIZE_ITEM_TIMER.record(() -> {
+        try {
+          final List<BatchStatementRequest> updateStatements = toNormalize.stream()
+              .map(account -> BatchStatementRequest.builder()
+                  .statement(updateUakStatement)
+                  .parameters(
+                      AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()),
+                      AttributeValues.fromUUID(account.getUuid()))
+                  .build())
+              .toList();
+
+          final BatchExecuteStatementResponse result = client.batchExecuteStatement(BatchExecuteStatementRequest
+              .builder()
+              .statements(updateStatements)
+              .build());
+
+          final Map<String, Long> errors = result.responses().stream()
+              .map(BatchStatementResponse::error)
+              .filter(e -> e != null)
+              .collect(Collectors.groupingBy(BatchStatementError::codeAsString, Collectors.counting()));
+
+          final long errorCount = errors.values().stream().mapToLong(Long::longValue).sum();
+          UAK_NORMALIZE_SUCCESS_COUNT.increment(toNormalize.size() - errorCount);
+          UAK_NORMALIZE_ERROR_COUNT.increment(errorCount);
+          if (!errors.isEmpty()) {
+            log.warn("Failed to normalize account uaks in batch of {}, error codes: {}", toNormalize.size(), errors);
+          }
+        } catch (final Exception e) {
+          UAK_NORMALIZE_ERROR_COUNT.increment(toNormalize.size());
+          log.warn("Failed to normalize accounts in a batch of {}", toNormalize.size(), e);
+        }
+      });
+    }
+    return allAccounts;
+  }
+
   private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
-
     scanRequestBuilder.tableName(accountsTableName);
-
-    final List<Account> accounts = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)
-        .stream()
-        .map(Accounts::fromItem)
-        .collect(Collectors.toList()));
-
+    final List<Map<String, AttributeValue>> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount));
+    final List<Account> accounts = normalizeIfRequired(items);
     return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null);
   }
 
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java
index 980bb504c..f6829f6b8 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java
@@ -28,17 +28,22 @@ import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.jdbi.v3.core.transaction.TransactionException;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
 import org.whispersystems.textsecuregcm.entities.SignedPreKey;
 import org.whispersystems.textsecuregcm.util.AttributeValues;
 import org.whispersystems.textsecuregcm.util.SystemMapper;
 import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
 import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
+import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
 import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
 import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
 import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
@@ -769,6 +774,115 @@ class AccountsTest {
     assertThat(account.getUsername()).hasValueSatisfying(u -> assertThat(u).isEqualTo(username));
   }
 
+  @Test
+  void testAddUakMissingInJson() {
+    // If there's no uak in the json, we shouldn't add an attribute on crawl
+    final UUID accountIdentifier = UUID.randomUUID();
+
+    final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
+    account.setUnidentifiedAccessKey(null);
+    accounts.create(account);
+
+    // there should be no top level uak
+    Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
+        .getItem(GetItemRequest.builder()
+            .tableName(ACCOUNTS_TABLE_NAME)
+            .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
+            .consistentRead(true)
+            .build()).item();
+    assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
+
+    // crawling should return 1 account
+    final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
+    assertThat(allFromStart.getAccounts()).hasSize(1);
+    assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
+
+    // there should still be no top level uak
+    item = dynamoDbExtension.getDynamoDbClient()
+        .getItem(GetItemRequest.builder()
+            .tableName(ACCOUNTS_TABLE_NAME)
+            .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
+            .consistentRead(true)
+            .build()).item();
+    assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
+  }
+
+  @Test
+  void testAddMissingUakAttribute() {
+    final UUID accountIdentifier = UUID.randomUUID();
+
+    final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
+    accounts.create(account);
+
+    // remove the top level uak (simulates old format)
+    dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
+        .tableName(ACCOUNTS_TABLE_NAME)
+        .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
+        .expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
+        .updateExpression("REMOVE #uak").build());
+
+    // crawling should return 1 account, and fix the discrepancy between
+    // the json blob and the top level attributes
+    final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
+    assertThat(allFromStart.getAccounts()).hasSize(1);
+    assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
+
+    // check that the attribute now exists at top level
+    final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
+        .getItem(GetItemRequest.builder()
+            .tableName(ACCOUNTS_TABLE_NAME)
+            .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
+            .consistentRead(true)
+            .build()).item();
+    assertThat(item).containsEntry(Accounts.ATTR_UAK,
+        AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
+  }
+
+  @ParameterizedTest
+  @ValueSource(ints = {24, 25, 26, 101})
+  void testAddMissingUakAttributeBatched(int n) {
+    // generate N + 5 accounts
+    List<Account> allAccounts = IntStream.range(0, n + 5)
+        .mapToObj(i -> generateAccount(String.format("+1800555%04d", i), UUID.randomUUID(), UUID.randomUUID()))
+        .collect(Collectors.toList());
+    allAccounts.forEach(accounts::create);
+
+    // delete the UAK on n of them
+    Collections.shuffle(allAccounts);
+    allAccounts.stream().limit(n).forEach(account ->
+        dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
+            .tableName(ACCOUNTS_TABLE_NAME)
+            .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
+            .expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
+            .updateExpression("REMOVE #uak")
+            .build()));
+
+    // crawling should fix the discrepancy between
+    // the json blob and the top level attributes
+    AccountCrawlChunk chunk = accounts.getAllFromStart(7);
+    long verifiedCount = 0;
+    while (true) {
+      for (Account account : chunk.getAccounts()) {
+        // check that the attribute now exists at top level
+        final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
+            .getItem(GetItemRequest.builder()
+                .tableName(ACCOUNTS_TABLE_NAME)
+                .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
+                .consistentRead(true)
+                .build()).item();
+        assertThat(item).containsEntry(Accounts.ATTR_UAK,
+            AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
+        verifiedCount++;
+      }
+      if (chunk.getLastUuid().isPresent()) {
+        chunk = accounts.getAllFrom(chunk.getLastUuid().get(), 7);
+      } else {
+        break;
+      }
+    }
+    assertThat(verifiedCount).isEqualTo(n + 5);
+  }
+
   private Device generateDevice(long id) {
     Random random = new Random(System.currentTimeMillis());
     SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(),