Use the async dynamo client to batch uak updates

This commit is contained in:
Ravi Khadiwala 2022-03-09 23:28:55 -06:00 committed by ravi-signal
parent de68c251f8
commit 5a88ff0811
9 changed files with 190 additions and 73 deletions

View File

@ -339,7 +339,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getAppConfig().getConfigurationName(),
DynamicConfiguration.class);
Accounts accounts = new Accounts(dynamicConfigurationManager, dynamoDbClient,
Accounts accounts = new Accounts(dynamicConfigurationManager,
dynamoDbClient,
dynamoDbAsyncClient,
config.getDynamoDbTables().getAccounts().getTableName(),
config.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
config.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),

View File

@ -6,7 +6,14 @@ public class DynamicUakMigrationConfiguration {
@JsonProperty
private boolean enabled = true;
@JsonProperty
private int maxOutstandingNormalizes = 25;
/**
 * Reports the value of the {@code enabled} JSON property of this UAK migration configuration.
 *
 * @return the configured flag; {@code true} unless overridden, since the field defaults to
 *     {@code true}
 */
public boolean isEnabled() {
return enabled;
}
/**
 * Returns the value of the {@code maxOutstandingNormalizes} JSON property.
 *
 * @return the configured value; 25 unless overridden, since the field defaults to 25
 */
// NOTE(review): presumably this bounds how many account normalizations may be in flight at
// once — confirm against the caller that consumes it.
public int getMaxOutstandingNormalizes() {
return maxOutstandingNormalizes;
}
}

View File

@ -8,39 +8,47 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicUakMigrationConfiguration;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchExecuteStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementError;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughputExceededException;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
@ -50,7 +58,7 @@ import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledExcepti
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.Update;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse;
import software.amazon.awssdk.utils.CompletableFutureUtils;
public class Accounts extends AbstractDynamoDbStore {
@ -71,8 +79,9 @@ public class Accounts extends AbstractDynamoDbStore {
// unidentified access key; byte[] or null
static final String ATTR_UAK = "UAK";
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient;
private final String phoneNumberConstraintTableName;
private final String phoneNumberIdentifierConstraintTableName;
@ -96,18 +105,21 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem"));
private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess"));
private static final Counter UAK_NORMALIZE_ERROR_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakError"));
private static final String UAK_NORMALIZE_ERROR_NAME = name(Accounts.class, "normalizeUakError");
private static final String UAK_NORMALIZE_FAILURE_REASON_TAG_NAME = "reason";
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
public Accounts(final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
DynamoDbClient client, String accountsTableName, String phoneNumberConstraintTableName,
DynamoDbClient client, DynamoDbAsyncClient asyncClient,
String accountsTableName, String phoneNumberConstraintTableName,
String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName,
final int scanPageSize) {
super(client);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.client = client;
this.asyncClient = asyncClient;
this.phoneNumberConstraintTableName = phoneNumberConstraintTableName;
this.phoneNumberIdentifierConstraintTableName = phoneNumberIdentifierConstraintTableName;
this.accountsTableName = accountsTableName;
@ -469,10 +481,19 @@ public class Accounts extends AbstractDynamoDbStore {
});
}
public void update(Account account) throws ContestedOptimisticLockException {
UPDATE_TIMER.record(() -> {
final UpdateItemRequest updateItemRequest;
/**
 * Peels {@link CompletionException} wrappers off a throwable.
 *
 * <p>Starting from {@code throwable}, follows the cause chain for as long as the current
 * throwable is a {@code CompletionException} that has a non-null cause.
 *
 * @param throwable the throwable to inspect; may be {@code null}
 * @return the first throwable in the chain that is either not a {@code CompletionException} or
 *     has no cause; {@code null} if {@code throwable} is {@code null}
 */
private static Throwable unwrap(Throwable throwable) {
  Throwable current = throwable;
  for (;;) {
    final boolean isWrapper = current instanceof CompletionException;
    if (!isWrapper || current.getCause() == null) {
      // either a real failure, or a wrapper with nothing inside it: stop here
      return current;
    }
    current = current.getCause();
  }
}
public CompletionStage<Void> updateAsync(Account account) {
return record(UPDATE_TIMER, () -> {
final UpdateItemRequest updateItemRequest;
try {
// username, e164, and pni cannot be modified through this method
Map<String, String> attrNames = new HashMap<>(Map.of(
@ -508,21 +529,41 @@ public class Accounts extends AbstractDynamoDbStore {
throw new IllegalArgumentException(e);
}
try {
final UpdateItemResponse response = client.updateItem(updateItemRequest);
account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1));
} catch (final TransactionConflictException e) {
throw new ContestedOptimisticLockException();
} catch (final ConditionalCheckFailedException e) {
// the exception doesn't give details about which condition failed,
// but we can infer it was an optimistic locking failure if the UUID is known
throw getByAccountIdentifier(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e;
}
return asyncClient.updateItem(updateItemRequest)
.thenApply(response -> {
account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1));
return (Void) null;
})
.exceptionally(throwable -> {
final Throwable unwrapped = unwrap(throwable);
if (unwrapped instanceof TransactionConflictException) {
throw new ContestedOptimisticLockException();
} else if (unwrapped instanceof ConditionalCheckFailedException e) {
// the exception doesn't give details about which condition failed,
// but we can infer it was an optimistic locking failure if the UUID is known
throw getByAccountIdentifier(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e;
} else {
// rethrow
throw CompletableFutureUtils.errorAsCompletionException(throwable);
}
});
});
}
/**
 * Blocking counterpart of {@link #updateAsync(Account)}: starts the asynchronous update and waits
 * for it to finish.
 *
 * <p>If the asynchronous update fails, the failure is stripped of its {@link CompletionException}
 * wrappers and, when it is unchecked, rethrown directly so callers see the original exception
 * type.
 *
 * @param account the account to update
 * @throws ContestedOptimisticLockException if the asynchronous update failed with this exception
 * @throws CompletionException if the underlying failure turned out to be a checked exception; the
 *     original wrapper is logged and rethrown unchanged
 */
public void update(Account account) throws ContestedOptimisticLockException {
  try {
    final CompletableFuture<Void> pendingUpdate = this.updateAsync(account).toCompletableFuture();
    pendingUpdate.join();
  } catch (CompletionException completionException) {
    // unwrap CompletionExceptions and rethrow the cause as long as it's unchecked
    final Throwable cause = unwrap(completionException);
    Throwables.throwIfUnchecked(cause);

    // only reached when the cause is a checked exception, which is not expected here:
    // report it and rethrow the original CompletionException that wraps it
    log.error("Unexpected checked exception thrown from dynamo update", completionException);
    throw completionException;
  }
}
public Optional<Account> getByE164(String number) {
return GET_BY_NUMBER_TIMER.record(() -> {
@ -641,6 +682,11 @@ public class Accounts extends AbstractDynamoDbStore {
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
}
/**
 * Times an asynchronous operation and records the elapsed time on the given timer.
 *
 * <p>The clock starts immediately before {@code toRecord} is invoked and stops when the returned
 * stage completes, whether normally or exceptionally. If {@code toRecord} itself throws instead
 * of returning a stage, nothing is recorded and the exception propagates to the caller.
 *
 * @param timer the timer that receives the measured duration
 * @param toRecord supplies the asynchronous operation to time
 * @param <T> the result type of the operation
 * @return a stage that completes with the same outcome as the supplied operation, after the
 *     duration has been recorded
 */
private static <T> CompletionStage<T> record(final Timer timer, Supplier<CompletionStage<T>> toRecord) {
  // System.nanoTime() is monotonic; the wall clock can be adjusted while the operation is in
  // flight, which would produce skewed or even negative durations.
  final long startNanos = System.nanoTime();
  return toRecord.get().whenComplete(
      (ignoredResult, ignoredError) -> timer.record(System.nanoTime() - startNanos, TimeUnit.NANOSECONDS));
}
private List<Account> normalizeIfRequired(final List<Map<String, AttributeValue>> items) {
// The UAK top-level attribute may not exist on older records,
@ -653,52 +699,62 @@ public class Accounts extends AbstractDynamoDbStore {
final Account account = fromItem(item);
allAccounts.add(account);
if (!item.containsKey(ATTR_UAK) && account.getUnidentifiedAccessKey().isPresent()) {
boolean hasAttrUak = item.containsKey(ATTR_UAK);
if (!hasAttrUak && account.getUnidentifiedAccessKey().isPresent()) {
// the top level uak attribute doesn't exist, but there's a uak in the account
accountsToNormalize.add(account);
} else if (hasAttrUak && account.getUnidentifiedAccessKey().isPresent()) {
final AttributeValue attr = item.get(ATTR_UAK);
final byte[] nestedUak = account.getUnidentifiedAccessKey().get();
if (!Arrays.equals(attr.b().asByteArray(), nestedUak)) {
log.warn("Discovered mismatch between attribute UAK data UAK, normalizing");
accountsToNormalize.add(account);
}
}
}
if (!this.dynamicConfigurationManager.getConfiguration().getUakMigrationConfiguration().isEnabled()) {
final DynamicUakMigrationConfiguration currentConfig = this.dynamicConfigurationManager.getConfiguration().getUakMigrationConfiguration();
if (!currentConfig.isEnabled()) {
log.debug("Account normalization is disabled, skipping normalization for {} accounts", accountsToNormalize.size());
return allAccounts;
}
final int BATCH_SIZE = 25; // dynamodb max batch size
final String updateUakStatement = String.format("UPDATE %s SET %s = ? WHERE %s = ?", accountsTableName, ATTR_UAK, KEY_ACCOUNT_UUID);
for (List<Account> toNormalize : Lists.partition(accountsToNormalize, BATCH_SIZE)) {
NORMALIZE_ITEM_TIMER.record(() -> {
try {
final List<BatchStatementRequest> updateStatements = toNormalize.stream()
.map(account -> BatchStatementRequest.builder()
.statement(updateUakStatement)
.parameters(
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()),
AttributeValues.fromUUID(account.getUuid()))
.build())
.toList();
for (List<Account> accounts : Lists.partition(accountsToNormalize, currentConfig.getMaxOutstandingNormalizes())) {
try {
final CompletableFuture<?>[] accountFutures = accounts.stream()
.map(account -> record(NORMALIZE_ITEM_TIMER,
() -> this.updateAsync(account).whenComplete((result, throwable) -> {
if (throwable == null) {
UAK_NORMALIZE_SUCCESS_COUNT.increment();
return;
}
final BatchExecuteStatementResponse result = client.batchExecuteStatement(BatchExecuteStatementRequest
.builder()
.statements(updateStatements)
.build());
throwable = unwrap(throwable);
if (throwable instanceof ContestedOptimisticLockException) {
// Could succeed on retry, but just backoff since this is a housekeeping operation
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "ContestedOptimisticLock")).increment();
} else if (throwable instanceof ProvisionedThroughputExceededException) {
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "ProvisionedThroughPutExceeded"))
.increment();
} else {
log.warn("Failed to normalize account, skipping", throwable);
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "unknown"))
.increment();
}
})).toCompletableFuture()).toArray(CompletableFuture[]::new);
final Map<String, Long> errors = result.responses().stream()
.map(BatchStatementResponse::error)
.filter(e -> e != null)
.collect(Collectors.groupingBy(BatchStatementError::codeAsString, Collectors.counting()));
final long errorCount = errors.values().stream().mapToLong(Long::longValue).sum();
UAK_NORMALIZE_SUCCESS_COUNT.increment(toNormalize.size() - errorCount);
UAK_NORMALIZE_ERROR_COUNT.increment(errorCount);
if (!errors.isEmpty()) {
log.warn("Failed to normalize account uaks in batch of {}, error codes: {}", toNormalize.size(), errors);
}
} catch (final Exception e) {
UAK_NORMALIZE_ERROR_COUNT.increment(toNormalize.size());
log.warn("Failed to normalize accounts in a batch of {}", toNormalize.size(), e);
}
});
// wait for all futures in the batch to complete
CompletableFuture
.allOf(accountFutures)
// exceptions handled in individual futures
.exceptionally(e -> null)
.join();
} catch (Exception e) {
log.warn("Failed to update batch of {} accounts, skipping", accounts.size(), e);
}
}
return allAccounts;
}

View File

@ -135,7 +135,9 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, dynamoDbClient,
Accounts accounts = new Accounts(dynamicConfigurationManager,
dynamoDbClient,
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),

View File

@ -138,7 +138,9 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, dynamoDbClient,
Accounts accounts = new Accounts(dynamicConfigurationManager,
dynamoDbClient,
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),

View File

@ -141,7 +141,9 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
VerificationCodeStore pendingAccounts = new VerificationCodeStore(dynamoDbClient,
configuration.getDynamoDbTables().getPendingAccounts().getTableName());
Accounts accounts = new Accounts(dynamicConfigurationManager, dynamoDbClient,
Accounts accounts = new Accounts(dynamicConfigurationManager,
dynamoDbClient,
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getAccounts().getTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),

View File

@ -154,6 +154,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Accounts accounts = new Accounts(
dynamicConfigurationManager,
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(),
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(),
ACCOUNTS_DYNAMO_EXTENSION.getTableName(),
NUMBERS_TABLE_NAME,
PNI_ASSIGNMENT_TABLE_NAME,

View File

@ -126,6 +126,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts = new Accounts(
dynamicConfigurationManager,
dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(),
dynamoDbExtension.getTableName(),
NUMBERS_TABLE_NAME,
PNI_TABLE_NAME,

View File

@ -28,6 +28,8 @@ import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.jdbi.v3.core.transaction.TransactionException;
@ -42,6 +44,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -139,6 +142,7 @@ class AccountsTest {
this.accounts = new Accounts(
mockDynamicConfigManager,
dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(),
dynamoDbExtension.getTableName(),
NUMBER_CONSTRAINT_TABLE_NAME,
PNI_CONSTRAINT_TABLE_NAME,
@ -377,15 +381,20 @@ class AccountsTest {
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), account, true);
}
@Test
void testUpdateWithMockTransactionConflictException() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUpdateWithMockTransactionConflictException(boolean wrapException) {
final DynamoDbClient dynamoDbClient = mock(DynamoDbClient.class);
accounts = new Accounts(mockDynamicConfigManager, dynamoDbClient,
dynamoDbExtension.getTableName(), NUMBER_CONSTRAINT_TABLE_NAME, PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE);
final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class);
accounts = new Accounts(mockDynamicConfigManager, mock(DynamoDbClient.class),
dynamoDbAsyncClient, dynamoDbExtension.getTableName(),
NUMBER_CONSTRAINT_TABLE_NAME, PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE);
when(dynamoDbClient.updateItem(any(UpdateItemRequest.class)))
.thenThrow(TransactionConflictException.class);
Exception e = TransactionConflictException.builder().build();
e = wrapException ? new CompletionException(e) : e;
when(dynamoDbAsyncClient.updateItem(any(UpdateItemRequest.class)))
.thenReturn(CompletableFuture.failedFuture(e));
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID());
@ -512,14 +521,15 @@ class AccountsTest {
configuration.setFailureRateThreshold(50);
final DynamoDbClient client = mock(DynamoDbClient.class);
final DynamoDbAsyncClient asyncClient = mock(DynamoDbAsyncClient.class);
when(client.transactWriteItems(any(TransactWriteItemsRequest.class)))
.thenThrow(RuntimeException.class);
when(client.updateItem(any(UpdateItemRequest.class)))
.thenThrow(RuntimeException.class);
when(asyncClient.updateItem(any(UpdateItemRequest.class)))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
Accounts accounts = new Accounts(mockDynamicConfigManager, client, ACCOUNTS_TABLE_NAME, NUMBER_CONSTRAINT_TABLE_NAME,
Accounts accounts = new Accounts(mockDynamicConfigManager, client, asyncClient, ACCOUNTS_TABLE_NAME, NUMBER_CONSTRAINT_TABLE_NAME,
PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID());
@ -816,6 +826,40 @@ class AccountsTest {
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
}
/**
 * Verifies that a crawl repairs a top-level UAK attribute that disagrees with the UAK stored in
 * the account: the attribute is overwritten with garbage, the accounts are crawled, and the
 * attribute is then expected to hold the account's original UAK again.
 */
@Test
void testUakMismatch() {
  // If there's a UAK mismatch, we should correct it
  final UUID accountIdentifier = UUID.randomUUID();
  final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
  accounts.create(account);

  final byte[] expectedUak = account.getUnidentifiedAccessKey().orElseThrow();
  final Map<String, AttributeValue> accountKey =
      Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier));

  // set the uak to garbage in the attributes; the charset is explicit so the bytes written do not
  // depend on the platform default
  final byte[] garbageUak = "bad-uak".getBytes(java.nio.charset.StandardCharsets.UTF_8);
  dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
      .tableName(ACCOUNTS_TABLE_NAME)
      .key(accountKey)
      .expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
      .expressionAttributeValues(Map.of(":uak", AttributeValues.fromByteArray(garbageUak)))
      .updateExpression("SET #uak = :uak").build());

  // crawling should return 1 account and fix the uak mismatch
  final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
  assertThat(allFromStart.getAccounts()).hasSize(1);
  assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
  assertThat(allFromStart.getAccounts().get(0).getUnidentifiedAccessKey().orElseThrow()).isEqualTo(expectedUak);

  // the top level uak should be the original; a consistent read avoids observing a stale item
  final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
      .getItem(GetItemRequest.builder()
          .tableName(ACCOUNTS_TABLE_NAME)
          .key(accountKey)
          .consistentRead(true)
          .build()).item();
  assertThat(item).containsEntry(
      Accounts.ATTR_UAK,
      AttributeValues.fromByteArray(expectedUak));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testAddMissingUakAttribute(boolean normalizeDisabled) throws JsonProcessingException {