diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java
index c6b96036a..88f05d947 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java
@@ -14,7 +14,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;
@@ -24,6 +23,7 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
+import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@@ -198,7 +198,6 @@ public class PushNotificationExperimentSamples {
*
* @param experimentName the name of the experiment for which to fetch samples
* @param stateClass the type of state object for sample in the given experiment
- * @param totalSegments the number of segments into which the scan of the backing data store will be divided
*
* @return a publisher of tuples of ACI, device ID, and sample for all samples associated with the given experiment
*
@@ -206,38 +205,11 @@ public class PushNotificationExperimentSamples {
*
* @see Working with scans - Parallel scan
*/
- public Flux> getSamples(final String experimentName,
- final Class stateClass,
- final int totalSegments,
- final Scheduler scheduler) {
+ public Flux> getSamples(final String experimentName, final Class stateClass) {
- // Note that we're using a DynamoDB Scan operation instead of a Query. A Query would allow us to limit the search
- // space to a specific experiment, but doesn't allow us to use segments. A Scan will always inspect all items in the
- // table, but allows us to segment the search. Since we're generally calling this method in conjunction with "…and
- // record a final state for the sample," distributing reads/writes across shards helps us avoid per-partition
- // throughput limits. If we wind up with many concurrent experiments, it may be worthwhile to revisit this decision.
-
- if (totalSegments < 1) {
- throw new IllegalArgumentException("Total number of segments must be positive");
- }
-
- return Flux.range(0, totalSegments)
- .parallel()
- .runOn(scheduler)
- .flatMap(segment -> getSamplesFromSegment(experimentName, stateClass, segment, totalSegments))
- .sequential();
- }
-
- private Flux> getSamplesFromSegment(final String experimentName,
- final Class stateClass,
- final int segment,
- final int totalSegments) {
-
- return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder()
+ return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
- .segment(segment)
- .totalSegments(totalSegments)
- .filterExpression("#experiment = :experiment")
+ .keyConditionExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName)))
.build())
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java
index d26950aaa..28cf4d5bd 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java
@@ -18,7 +18,6 @@ import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import java.time.Duration;
@@ -32,9 +31,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW
@VisibleForTesting
static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
- @VisibleForTesting
- static final String SEGMENT_COUNT_ARGUMENT = "segments";
-
private static final String SAMPLES_READ_COUNTER_NAME =
MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "samplesRead");
@@ -68,13 +64,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
-
- subparser.addArgument("--segments")
- .type(Integer.class)
- .dest(SEGMENT_COUNT_ARGUMENT)
- .required(false)
- .setDefault(16)
- .help("The total number of segments for a DynamoDB scan");
}
@Override
@@ -87,7 +76,6 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW
experimentFactory.buildExperiment(commandDependencies, configuration);
final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT);
- final int segments = namespace.getInt(SEGMENT_COUNT_ARGUMENT);
log.info("Finishing \"{}\" with max concurrency: {}", experiment.getExperimentName(), maxConcurrency);
@@ -95,10 +83,7 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW
final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples();
final Flux> finishedSamples =
- pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(),
- experiment.getStateClass(),
- segments,
- Schedulers.parallel())
+ pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(), experiment.getStateClass())
.doOnNext(sample -> Metrics.counter(SAMPLES_READ_COUNTER_NAME, "final", String.valueOf(sample.finalState() != null)).increment())
.flatMap(sample -> {
if (sample.finalState() == null) {
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java
index ab8e4818b..734b44c16 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java
@@ -7,7 +7,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.time.Clock;
-import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
@@ -21,7 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.util.SystemMapper;
-import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
@@ -250,7 +248,7 @@ class PushNotificationExperimentSamplesTest {
new TestDeviceState(finalBounciness)));
assertEquals(expectedSamples,
- pushNotificationExperimentSamples.getSamples(experimentName, TestDeviceState.class, 1, Schedulers.immediate()).collect(Collectors.toSet()).block());
+ pushNotificationExperimentSamples.getSamples(experimentName, TestDeviceState.class).collect(Collectors.toSet()).block());
}
@Test
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java
index 7e30f1d19..9a67bce4d 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java
@@ -3,7 +3,6 @@ package org.whispersystems.textsecuregcm.workers;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
-import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@@ -37,9 +36,8 @@ class FinishPushNotificationExperimentCommandTest {
private static final String EXPERIMENT_NAME = "test";
- private static final Namespace NAMESPACE = new Namespace(Map.of(
- FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1,
- FinishPushNotificationExperimentCommand.SEGMENT_COUNT_ARGUMENT, 1));
+ private static final Namespace NAMESPACE =
+ new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1));
private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand {
@@ -112,7 +110,7 @@ class FinishPushNotificationExperimentCommandTest {
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));
@@ -129,7 +127,7 @@ class FinishPushNotificationExperimentCommandTest {
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));
@@ -149,7 +147,7 @@ class FinishPushNotificationExperimentCommandTest {
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));
@@ -174,7 +172,7 @@ class FinishPushNotificationExperimentCommandTest {
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));
@@ -199,7 +197,7 @@ class FinishPushNotificationExperimentCommandTest {
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
@@ -227,7 +225,7 @@ class FinishPushNotificationExperimentCommandTest {
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
@@ -241,7 +239,7 @@ class FinishPushNotificationExperimentCommandTest {
@Test
void runFinalSampleAlreadyRecorded() {
- when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any()))
+ when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class)))
.thenReturn(Flux.just(new PushNotificationExperimentSample<>(UUID.randomUUID(), Device.PRIMARY_ID, true, "test", "test")));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies));