Add a method for iterating across all accounts

This commit is contained in:
Jon Chambers 2023-06-20 11:42:22 -04:00 committed by Jon Chambers
parent 97710540c0
commit 06997e19e0
3 changed files with 42 additions and 0 deletions

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -677,6 +678,23 @@ public class Accounts extends AbstractDynamoDbStore {
})); }));
} }
Flux<Account> getAll(final int segments) {
if (segments < 1) {
throw new IllegalArgumentException("Total number of segments must be positive");
}
return Flux.merge(
Flux.range(0, segments)
.map(segment -> asyncClient.scanPaginator(ScanRequest.builder()
.tableName(accountsTableName)
.consistentRead(true)
.segment(segment)
.totalSegments(segments)
.build())
.items()
.map(Accounts::fromItem)));
}
@Nonnull @Nonnull
public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) { public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) {
final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder()

View File

@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
public class AccountsManager { public class AccountsManager {
@ -714,6 +715,10 @@ public class AccountsManager {
return accounts.getAllFrom(uuid, length); return accounts.getAllFrom(uuid, length);
} }
public Flux<Account> streamAllFromDynamo(final int segments) {
return accounts.getAll(segments);
}
public void delete(final Account account, final DeletionReason deletionReason) throws InterruptedException { public void delete(final Account account, final DeletionReason deletionReason) throws InterruptedException {
try (final Timer.Context ignored = deleteTimer.time()) { try (final Timer.Context ignored = deleteTimer.time()) {
accountLockManager.withLock(List.of(account.getNumber()), () -> { accountLockManager.withLock(List.of(account.getNumber()), () -> {

View File

@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -36,6 +37,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -492,6 +494,23 @@ class AccountsTest {
assertThat(users).isEmpty(); assertThat(users).isEmpty();
} }
@Test
void testGetAll() {
final List<Account> expectedAccounts = new ArrayList<>();
for (int i = 1; i <= 100; i++) {
final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID());
expectedAccounts.add(account);
accounts.create(account);
}
final List<Account> retrievedAccounts = accounts.getAll(2).collectList().block();
assertNotNull(retrievedAccounts);
assertEquals(expectedAccounts.stream().map(Account::getUuid).collect(Collectors.toSet()),
retrievedAccounts.stream().map(Account::getUuid).collect(Collectors.toSet()));
}
@Test @Test
void testDelete() { void testDelete() {
final Device deletedDevice = generateDevice(1); final Device deletedDevice = generateDevice(1);