From 9cb098ad8a33ecec743ab896ba4c24ba97386ba0 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Tue, 1 Mar 2022 14:59:26 -0600 Subject: [PATCH] Add a top-level uak to existing items Items wirtten before we started storing the uak at the top level only store the uak in the account blob. The will be updated on account crawl --- .../textsecuregcm/storage/Accounts.java | 77 ++++++++++-- .../textsecuregcm/storage/AccountsTest.java | 114 ++++++++++++++++++ 2 files changed, 184 insertions(+), 7 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 2cf36cfb2..bf37c3548 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; +import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.io.IOException; @@ -28,6 +30,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UUIDUtil; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementRequest; +import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementResponse; +import software.amazon.awssdk.services.dynamodb.model.BatchStatementError; +import software.amazon.awssdk.services.dynamodb.model.BatchStatementRequest; +import software.amazon.awssdk.services.dynamodb.model.BatchStatementResponse; import software.amazon.awssdk.services.dynamodb.model.CancellationReason; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.Delete; @@ -84,6 +91,10 @@ public class Accounts extends AbstractDynamoDbStore { private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom")); private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset")); private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete")); + private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem")); + + private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess")); + private static final Counter UAK_NORMALIZE_ERROR_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakError")); private static final Logger log = LoggerFactory.getLogger(Accounts.class); @@ -627,15 +638,67 @@ public class Accounts extends AbstractDynamoDbStore { return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER); } + private List normalizeIfRequired(final List> items) { + + // The UAK top-level attribute may not exist on older records, + // if it is absent and there is a UAK in the account blob we'll + // add the UAK as a top-level attribute + // TODO: Can eliminate this once all uaks exist as top-level attributes + final List allAccounts = new ArrayList<>(); + final List accountsToNormalize = new ArrayList<>(); + for (Map item : items) { + final Account account = fromItem(item); + allAccounts.add(account); + + if (!item.containsKey(ATTR_UAK) && account.getUnidentifiedAccessKey().isPresent()) { + // the top level uak attribute doesn't exist, but there's a uak in the account + accountsToNormalize.add(account); + } + } + + final int BATCH_SIZE = 25; // dynamodb max batch size + final String updateUakStatement = String.format("UPDATE %s SET %s = ? WHERE %s = ?", accountsTableName, ATTR_UAK, KEY_ACCOUNT_UUID); + for (List toNormalize : Lists.partition(accountsToNormalize, BATCH_SIZE)) { + NORMALIZE_ITEM_TIMER.record(() -> { + try { + final List updateStatements = toNormalize.stream() + .map(account -> BatchStatementRequest.builder() + .statement(updateUakStatement) + .parameters( + AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()), + AttributeValues.fromUUID(account.getUuid())) + .build()) + .toList(); + + final BatchExecuteStatementResponse result = client.batchExecuteStatement(BatchExecuteStatementRequest + .builder() + .statements(updateStatements) + .build()); + + final Map errors = result.responses().stream() + .map(BatchStatementResponse::error) + .filter(e -> e != null) + .collect(Collectors.groupingBy(BatchStatementError::codeAsString, Collectors.counting())); + + final long errorCount = errors.values().stream().mapToLong(Long::longValue).sum(); + UAK_NORMALIZE_SUCCESS_COUNT.increment(toNormalize.size() - errorCount); + UAK_NORMALIZE_ERROR_COUNT.increment(errorCount); + if (!errors.isEmpty()) { + log.warn("Failed to normalize account uaks in batch of {}, error codes: {}", toNormalize.size(), errors); + } + } catch (final Exception e) { + UAK_NORMALIZE_ERROR_COUNT.increment(toNormalize.size()); + log.warn("Failed to normalize accounts in a batch of {}", toNormalize.size(), e); + } + }); + } + return allAccounts; + } + private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) { - scanRequestBuilder.tableName(accountsTableName); - - final List accounts = timer.record(() -> scan(scanRequestBuilder.build(), maxCount) - .stream() - .map(Accounts::fromItem) - .collect(Collectors.toList())); - + final List> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)); + final List accounts = normalizeIfRequired(items); return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 980bb504c..f6829f6b8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -28,17 +28,22 @@ import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.jdbi.v3.core.transaction.TransactionException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.SystemMapper; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; @@ -769,6 +774,115 @@ class AccountsTest { assertThat(account.getUsername()).hasValueSatisfying(u -> assertThat(u).isEqualTo(username)); } + @Test + void testAddUakMissingInJson() { + // If there's no uak in the json, we shouldn't add an attribute on crawl + final UUID accountIdentifier = UUID.randomUUID(); + + final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID()); + account.setUnidentifiedAccessKey(null); + accounts.create(account); + + // there should be no top level uak + Map item = dynamoDbExtension.getDynamoDbClient() + .getItem(GetItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier))) + .consistentRead(true) + .build()).item(); + assertThat(item).doesNotContainKey(Accounts.ATTR_UAK); + + // crawling should return 1 account + final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1); + assertThat(allFromStart.getAccounts()).hasSize(1); + assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier); + + // there should still be no top level uak + item = dynamoDbExtension.getDynamoDbClient() + .getItem(GetItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier))) + .consistentRead(true) + .build()).item(); + assertThat(item).doesNotContainKey(Accounts.ATTR_UAK); + } + + @Test + void testAddMissingUakAttribute() { + final UUID accountIdentifier = UUID.randomUUID(); + + final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID()); + accounts.create(account); + + // remove the top level uak (simulates old format) + dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier))) + .expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK)) + .updateExpression("REMOVE #uak").build()); + + // crawling should return 1 account, and fix the discrepancy between + // the json blob and the top level attributes + final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1); + assertThat(allFromStart.getAccounts()).hasSize(1); + assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier); + + // check that the attribute now exists at top level + final Map item = dynamoDbExtension.getDynamoDbClient() + .getItem(GetItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier))) + .consistentRead(true) + .build()).item(); + assertThat(item).containsEntry(Accounts.ATTR_UAK, + AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get())); + } + + @ParameterizedTest + @ValueSource(ints = {24, 25, 26, 101}) + void testAddMissingUakAttributeBatched(int n) { + // generate N + 5 accounts + List allAccounts = IntStream.range(0, n + 5) + .mapToObj(i -> generateAccount(String.format("+1800555%04d", i), UUID.randomUUID(), UUID.randomUUID())) + .collect(Collectors.toList()); + allAccounts.forEach(accounts::create); + + // delete the UAK on n of them + Collections.shuffle(allAccounts); + allAccounts.stream().limit(n).forEach(account -> + dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) + .expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK)) + .updateExpression("REMOVE #uak") + .build())); + + // crawling should fix the discrepancy between + // the json blob and the top level attributes + AccountCrawlChunk chunk = accounts.getAllFromStart(7); + long verifiedCount = 0; + while (true) { + for (Account account : chunk.getAccounts()) { + // check that the attribute now exists at top level + final Map item = dynamoDbExtension.getDynamoDbClient() + .getItem(GetItemRequest.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) + .consistentRead(true) + .build()).item(); + assertThat(item).containsEntry(Accounts.ATTR_UAK, + AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get())); + verifiedCount++; + } + if (chunk.getLastUuid().isPresent()) { + chunk = accounts.getAllFrom(chunk.getLastUuid().get(), 7); + } else { + break; + } + } + assertThat(verifiedCount).isEqualTo(n + 5); + } + private Device generateDevice(long id) { Random random = new Random(System.currentTimeMillis()); SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(),