Add a top-level uak to existing items

Items wirtten before we started storing the uak at
the top level only store the uak in the
account blob. The will be updated on account
crawl
This commit is contained in:
Ravi Khadiwala 2022-03-01 14:59:26 -06:00 committed by ravi-signal
parent 6283f5952d
commit 9cb098ad8a
2 changed files with 184 additions and 7 deletions

View File

@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
@ -28,6 +30,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementError;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
@ -84,6 +91,10 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem"));
private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess"));
private static final Counter UAK_NORMALIZE_ERROR_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakError"));
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
@ -627,15 +638,67 @@ public class Accounts extends AbstractDynamoDbStore {
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
}
private List<Account> normalizeIfRequired(final List<Map<String, AttributeValue>> items) {
// The UAK top-level attribute may not exist on older records,
// if it is absent and there is a UAK in the account blob we'll
// add the UAK as a top-level attribute
// TODO: Can eliminate this once all uaks exist as top-level attributes
final List<Account> allAccounts = new ArrayList<>();
final List<Account> accountsToNormalize = new ArrayList<>();
for (Map<String, AttributeValue> item : items) {
final Account account = fromItem(item);
allAccounts.add(account);
if (!item.containsKey(ATTR_UAK) && account.getUnidentifiedAccessKey().isPresent()) {
// the top level uak attribute doesn't exist, but there's a uak in the account
accountsToNormalize.add(account);
}
}
final int BATCH_SIZE = 25; // dynamodb max batch size
final String updateUakStatement = String.format("UPDATE %s SET %s = ? WHERE %s = ?", accountsTableName, ATTR_UAK, KEY_ACCOUNT_UUID);
for (List<Account> toNormalize : Lists.partition(accountsToNormalize, BATCH_SIZE)) {
NORMALIZE_ITEM_TIMER.record(() -> {
try {
final List<BatchStatementRequest> updateStatements = toNormalize.stream()
.map(account -> BatchStatementRequest.builder()
.statement(updateUakStatement)
.parameters(
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()),
AttributeValues.fromUUID(account.getUuid()))
.build())
.toList();
final BatchExecuteStatementResponse result = client.batchExecuteStatement(BatchExecuteStatementRequest
.builder()
.statements(updateStatements)
.build());
final Map<String, Long> errors = result.responses().stream()
.map(BatchStatementResponse::error)
.filter(e -> e != null)
.collect(Collectors.groupingBy(BatchStatementError::codeAsString, Collectors.counting()));
final long errorCount = errors.values().stream().mapToLong(Long::longValue).sum();
UAK_NORMALIZE_SUCCESS_COUNT.increment(toNormalize.size() - errorCount);
UAK_NORMALIZE_ERROR_COUNT.increment(errorCount);
if (!errors.isEmpty()) {
log.warn("Failed to normalize account uaks in batch of {}, error codes: {}", toNormalize.size(), errors);
}
} catch (final Exception e) {
UAK_NORMALIZE_ERROR_COUNT.increment(toNormalize.size());
log.warn("Failed to normalize accounts in a batch of {}", toNormalize.size(), e);
}
});
}
return allAccounts;
}
private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
scanRequestBuilder.tableName(accountsTableName);
final List<Account> accounts = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)
.stream()
.map(Accounts::fromItem)
.collect(Collectors.toList()));
final List<Map<String, AttributeValue>> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount));
final List<Account> accounts = normalizeIfRequired(items);
return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null);
}

View File

@ -28,17 +28,22 @@ import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.jdbi.v3.core.transaction.TransactionException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
@ -769,6 +774,115 @@ class AccountsTest {
assertThat(account.getUsername()).hasValueSatisfying(u -> assertThat(u).isEqualTo(username));
}
@Test
void testAddUakMissingInJson() {
// If there's no uak in the json, we shouldn't add an attribute on crawl
final UUID accountIdentifier = UUID.randomUUID();
final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
account.setUnidentifiedAccessKey(null);
accounts.create(account);
// there should be no top level uak
Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
// crawling should return 1 account
final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
assertThat(allFromStart.getAccounts()).hasSize(1);
assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
// there should still be no top level uak
item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
}
@Test
void testAddMissingUakAttribute() {
final UUID accountIdentifier = UUID.randomUUID();
final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
accounts.create(account);
// remove the top level uak (simulates old format)
dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
.updateExpression("REMOVE #uak").build());
// crawling should return 1 account, and fix the discrepancy between
// the json blob and the top level attributes
final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
assertThat(allFromStart.getAccounts()).hasSize(1);
assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
// check that the attribute now exists at top level
final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).containsEntry(Accounts.ATTR_UAK,
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
}
@ParameterizedTest
@ValueSource(ints = {24, 25, 26, 101})
void testAddMissingUakAttributeBatched(int n) {
// generate N + 5 accounts
List<Account> allAccounts = IntStream.range(0, n + 5)
.mapToObj(i -> generateAccount(String.format("+1800555%04d", i), UUID.randomUUID(), UUID.randomUUID()))
.collect(Collectors.toList());
allAccounts.forEach(accounts::create);
// delete the UAK on n of them
Collections.shuffle(allAccounts);
allAccounts.stream().limit(n).forEach(account ->
dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
.updateExpression("REMOVE #uak")
.build()));
// crawling should fix the discrepancy between
// the json blob and the top level attributes
AccountCrawlChunk chunk = accounts.getAllFromStart(7);
long verifiedCount = 0;
while (true) {
for (Account account : chunk.getAccounts()) {
// check that the attribute now exists at top level
final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.consistentRead(true)
.build()).item();
assertThat(item).containsEntry(Accounts.ATTR_UAK,
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
verifiedCount++;
}
if (chunk.getLastUuid().isPresent()) {
chunk = accounts.getAllFrom(chunk.getLastUuid().get(), 7);
} else {
break;
}
}
assertThat(verifiedCount).isEqualTo(n + 5);
}
private Device generateDevice(long id) {
Random random = new Random(System.currentTimeMillis());
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(),