diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java index c6b96036a..88f05d947 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java @@ -14,7 +14,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Scheduler; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; import reactor.util.retry.Retry; @@ -24,6 +23,7 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; @@ -198,7 +198,6 @@ public class PushNotificationExperimentSamples { * * @param experimentName the name of the experiment for which to fetch samples * @param stateClass the type of state object for sample in the given experiment - * @param totalSegments the number of segments into which the scan of the backing data store will be divided * * @return a publisher of tuples of ACI, device ID, and sample for all samples associated with the given experiment * @@ -206,38 +205,11 @@ public class PushNotificationExperimentSamples { * * @see Working with scans - Parallel scan */ - public Flux> getSamples(final String experimentName, - final Class stateClass, - final int totalSegments, - final Scheduler scheduler) { + public Flux> getSamples(final String experimentName, final Class stateClass) { - // Note that we're using a DynamoDB Scan operation instead of a Query. A Query would allow us to limit the search - // space to a specific experiment, but doesn't allow us to use segments. A Scan will always inspect all items in the - // table, but allows us to segment the search. Since we're generally calling this method in conjunction with "…and - // record a final state for the sample," distributing reads/writes across shards helps us avoid per-partition - // throughput limits. If we wind up with many concurrent experiments, it may be worthwhile to revisit this decision. - - if (totalSegments < 1) { - throw new IllegalArgumentException("Total number of segments must be positive"); - } - - return Flux.range(0, totalSegments) - .parallel() - .runOn(scheduler) - .flatMap(segment -> getSamplesFromSegment(experimentName, stateClass, segment, totalSegments)) - .sequential(); - } - - private Flux> getSamplesFromSegment(final String experimentName, - final Class stateClass, - final int segment, - final int totalSegments) { - - return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder() + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) - .segment(segment) - .totalSegments(totalSegments) - .filterExpression("#experiment = :experiment") + .keyConditionExpression("#experiment = :experiment") .expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME)) .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName))) .build()) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java index d26950aaa..28cf4d5bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java @@ -18,7 +18,6 @@ import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.AccountsManager; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import java.time.Duration; @@ -32,9 +31,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW @VisibleForTesting static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; - @VisibleForTesting - static final String SEGMENT_COUNT_ARGUMENT = "segments"; - private static final String SAMPLES_READ_COUNTER_NAME = MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "samplesRead"); @@ -68,13 +64,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW .dest(MAX_CONCURRENCY_ARGUMENT) .setDefault(DEFAULT_MAX_CONCURRENCY) .help("Max concurrency for DynamoDB operations"); - - subparser.addArgument("--segments") - .type(Integer.class) - .dest(SEGMENT_COUNT_ARGUMENT) - .required(false) - .setDefault(16) - .help("The total number of segments for a DynamoDB scan"); } @Override @@ -87,7 +76,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW experimentFactory.buildExperiment(commandDependencies, configuration); final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT); - final int segments = namespace.getInt(SEGMENT_COUNT_ARGUMENT); log.info("Finishing \"{}\" with max concurrency: {}", experiment.getExperimentName(), maxConcurrency); @@ -95,10 +83,7 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples(); final Flux> finishedSamples = - pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(), - experiment.getStateClass(), - segments, - Schedulers.parallel()) + pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(), experiment.getStateClass()) .doOnNext(sample -> Metrics.counter(SAMPLES_READ_COUNTER_NAME, "final", String.valueOf(sample.finalState() != null)).increment()) .flatMap(sample -> { if (sample.finalState() == null) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java index ab8e4818b..734b44c16 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java @@ -7,7 +7,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import com.fasterxml.jackson.core.JsonProcessingException; import java.time.Clock; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; @@ -21,7 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; import org.whispersystems.textsecuregcm.util.SystemMapper; -import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; @@ -250,7 +248,7 @@ class PushNotificationExperimentSamplesTest { new TestDeviceState(finalBounciness))); assertEquals(expectedSamples, - pushNotificationExperimentSamples.getSamples(experimentName, TestDeviceState.class, 1, Schedulers.immediate()).collect(Collectors.toSet()).block()); + pushNotificationExperimentSamples.getSamples(experimentName, TestDeviceState.class).collect(Collectors.toSet()).block()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java index 7e30f1d19..9a67bce4d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java @@ -3,7 +3,6 @@ package org.whispersystems.textsecuregcm.workers; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -37,9 +36,8 @@ class FinishPushNotificationExperimentCommandTest { private static final String EXPERIMENT_NAME = "test"; - private static final Namespace NAMESPACE = new Namespace(Map.of( - FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1, - FinishPushNotificationExperimentCommand.SEGMENT_COUNT_ARGUMENT, 1)); + private static final Namespace NAMESPACE = + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)); private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand { @@ -112,7 +110,7 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); @@ -129,7 +127,7 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); @@ -149,7 +147,7 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); @@ -174,7 +172,7 @@ class FinishPushNotificationExperimentCommandTest { .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); @@ -199,7 +197,7 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) @@ -227,7 +225,7 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) @@ -241,7 +239,7 @@ class FinishPushNotificationExperimentCommandTest { @Test void runFinalSampleAlreadyRecorded() { - when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class))) .thenReturn(Flux.just(new PushNotificationExperimentSample<>(UUID.randomUUID(), Device.PRIMARY_ID, true, "test", "test"))); assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));