diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 88212fde8..7151507a1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.core.publisher.Flux; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -677,6 +678,23 @@ public class Accounts extends AbstractDynamoDbStore { })); } + Flux getAll(final int segments) { + if (segments < 1) { + throw new IllegalArgumentException("Total number of segments must be positive"); + } + + return Flux.merge( + Flux.range(0, segments) + .map(segment -> asyncClient.scanPaginator(ScanRequest.builder() + .tableName(accountsTableName) + .consistentRead(true) + .segment(segment) + .totalSegments(segments) + .build()) + .items() + .map(Accounts::fromItem))); + } + @Nonnull public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount) { final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder() diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index b51078b2b..ad349fb60 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; public class AccountsManager { @@ -714,6 +715,10 @@ public class AccountsManager { return accounts.getAllFrom(uuid, length); } + public Flux streamAllFromDynamo(final int segments) { + return accounts.getAll(segments); + } + public void delete(final Account account, final DeletionReason deletionReason) throws InterruptedException { try (final Timer.Context ignored = deleteTimer.time()) { accountLockManager.withLock(List.of(account.getNumber()), () -> { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 1c8877943..2e50c6815 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -36,6 +37,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import org.apache.commons.lang3.RandomUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -492,6 +494,23 @@ class AccountsTest { assertThat(users).isEmpty(); } + @Test + void testGetAll() { + final List expectedAccounts = new ArrayList<>(); + + for (int i = 1; i <= 100; i++) { + final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID()); + expectedAccounts.add(account); + accounts.create(account); + } + + final List retrievedAccounts = accounts.getAll(2).collectList().block(); + + assertNotNull(retrievedAccounts); + assertEquals(expectedAccounts.stream().map(Account::getUuid).collect(Collectors.toSet()), + retrievedAccounts.stream().map(Account::getUuid).collect(Collectors.toSet())); + } + @Test void testDelete() { final Device deletedDevice = generateDevice(1);