diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java index 2e68cd10f..294c28c21 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java @@ -30,8 +30,19 @@ public class StartPushNotificationExperimentCommand extends AbstractSinglePas @VisibleForTesting static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; - private static final Counter INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER = - Metrics.counter(MetricsUtil.name(StartPushNotificationExperimentCommand.class, "initialSampleAlreadyExists")); + @VisibleForTesting + static final String DRY_RUN_ARGUMENT = "dry-run"; + + private static final Counter DEVICE_INSPECTED_COUNTER = + Metrics.counter(MetricsUtil.name(StartPushNotificationExperimentCommand.class, "deviceInspected")); + + private static final String RECORD_INITIAL_SAMPLE_COUNTER_NAME = + MetricsUtil.name(StartPushNotificationExperimentCommand.class, "recordInitialSample"); + + private static final String APPLY_TREATMENT_COUNTER_NAME = + MetricsUtil.name(StartPushNotificationExperimentCommand.class, "applyTreatment"); + + private static final String DRY_RUN_TAG_NAME = "dryRun"; private static final Logger log = LoggerFactory.getLogger(StartPushNotificationExperimentCommand.class); @@ -52,11 +63,19 @@ public class StartPushNotificationExperimentCommand extends AbstractSinglePas .dest(MAX_CONCURRENCY_ARGUMENT) .setDefault(DEFAULT_MAX_CONCURRENCY) .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--dry-run") + .type(Boolean.class) + .dest(DRY_RUN_ARGUMENT) + .required(false) + .setDefault(true) + .help("If true, don't actually record samples or apply treatments"); } @Override protected void crawlAccounts(final Flux accounts) { final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); final PushNotificationExperiment experiment = experimentFactory.buildExperiment(getCommandDependencies(), getConfiguration()); @@ -67,6 +86,7 @@ public class StartPushNotificationExperimentCommand extends AbstractSinglePas accounts .flatMap(account -> Flux.fromIterable(account.getDevices()) .map(device -> Tuples.of(account, device))) + .doOnNext(ignored -> DEVICE_INSPECTED_COUNTER.increment()) .filterWhen(accountAndDevice -> Mono.fromFuture(() -> experiment.isDeviceEligible(accountAndDevice.getT1(), accountAndDevice.getT2())), maxConcurrency) @@ -74,28 +94,31 @@ public class StartPushNotificationExperimentCommand extends AbstractSinglePas final UUID accountIdentifier = accountAndDevice.getT1().getIdentifier(IdentityType.ACI); final byte deviceId = accountAndDevice.getT2().getId(); - return Mono.fromFuture(() -> { - try { - return pushNotificationExperimentSamples.recordInitialState( + final Mono recordInitialSampleMono = dryRun + ? Mono.just(true) + : Mono.fromFuture(() -> { + try { + return pushNotificationExperimentSamples.recordInitialState( accountIdentifier, deviceId, experiment.getExperimentName(), isInExperimentGroup(accountIdentifier, deviceId, experiment.getExperimentName()), experiment.getState(accountAndDevice.getT1(), accountAndDevice.getT2())); - } catch (final JsonProcessingException e) { - throw new UncheckedIOException(e); - } + } catch (final JsonProcessingException e) { + throw new UncheckedIOException(e); + } + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) + .onRetryExhaustedThrow(((backoffSpec, retrySignal) -> retrySignal.failure()))); + + return recordInitialSampleMono.mapNotNull(stateStored -> { + Metrics.counter(RECORD_INITIAL_SAMPLE_COUNTER_NAME, + DRY_RUN_TAG_NAME, String.valueOf(dryRun), + "initialSampleAlreadyExists", String.valueOf(!stateStored)) + .increment(); + + return stateStored ? accountAndDevice : null; }) - .mapNotNull(stateStored -> { - if (stateStored) { - return accountAndDevice; - } else { - INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER.increment(); - return null; - } - }) - .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) - .onRetryExhaustedThrow(((backoffSpec, retrySignal) -> retrySignal.failure()))) .onErrorResume(throwable -> { log.warn("Failed to record initial sample for {}:{} in experiment {}", accountIdentifier, deviceId, experiment.getExperimentName(), throwable); @@ -109,19 +132,26 @@ public class StartPushNotificationExperimentCommand extends AbstractSinglePas final boolean inExperimentGroup = isInExperimentGroup(account.getIdentifier(IdentityType.ACI), device.getId(), experiment.getExperimentName()); - return Mono.fromFuture(() -> inExperimentGroup - ? experiment.applyExperimentTreatment(account, device) - : experiment.applyControlTreatment(account, device)) - .onErrorResume(throwable -> { - log.warn("Failed to apply {} treatment for {}:{} in experiment {}", - inExperimentGroup ? "experimental" : " control", - account.getIdentifier(IdentityType.ACI), - device.getId(), - experiment.getExperimentName(), - throwable); + final Mono applyTreatmentMono = dryRun + ? Mono.empty() + : Mono.fromFuture(() -> inExperimentGroup + ? experiment.applyExperimentTreatment(account, device) + : experiment.applyControlTreatment(account, device)) + .onErrorResume(throwable -> { + log.warn("Failed to apply {} treatment for {}:{} in experiment {}", + inExperimentGroup ? "experimental" : " control", + account.getIdentifier(IdentityType.ACI), + device.getId(), + experiment.getExperimentName(), + throwable); - return Mono.empty(); - }); + return Mono.empty(); + }); + + return applyTreatmentMono + .doOnSuccess(ignored -> Metrics.counter(APPLY_TREATMENT_COUNTER_NAME, + DRY_RUN_TAG_NAME, String.valueOf(dryRun), + "treatment", inExperimentGroup ? "experiment" : "control").increment()); }, maxConcurrency) .then() .block(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java index 4cbacb909..20ed2ecbb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java @@ -18,6 +18,8 @@ import java.util.concurrent.CompletableFuture; import net.sourceforge.argparse4j.inf.Namespace; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -30,7 +32,7 @@ class StartPushNotificationExperimentCommandTest { private PushNotificationExperimentSamples pushNotificationExperimentSamples; private PushNotificationExperiment experiment; - private StartPushNotificationExperimentCommand startPushNotificationExperimentCommand; + private TestStartPushNotificationExperimentCommand startPushNotificationExperimentCommand; // Taken together, these parameters will produce a device that's enrolled in the experimental group (as opposed to the // control group) for an experiment. @@ -41,6 +43,7 @@ class StartPushNotificationExperimentCommandTest { private static class TestStartPushNotificationExperimentCommand extends StartPushNotificationExperimentCommand { private final CommandDependencies commandDependencies; + private boolean dryRun = false; public TestStartPushNotificationExperimentCommand( final PushNotificationExperimentSamples pushNotificationExperimentSamples, @@ -67,9 +70,15 @@ class StartPushNotificationExperimentCommandTest { null); } + void setDryRun(final boolean dryRun) { + this.dryRun = dryRun; + } + @Override protected Namespace getNamespace() { - return new Namespace(Map.of(StartPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)); + return new Namespace(Map.of( + StartPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1, + StartPushNotificationExperimentCommand.DRY_RUN_ARGUMENT, dryRun)); } @Override @@ -100,8 +109,11 @@ class StartPushNotificationExperimentCommandTest { new TestStartPushNotificationExperimentCommand(pushNotificationExperimentSamples, experiment); } - @Test - void crawlAccounts() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccounts(final boolean dryRun) { + startPushNotificationExperimentCommand.setDryRun(dryRun); + final Device device = mock(Device.class); when(device.getId()).thenReturn(DEVICE_ID); @@ -110,11 +122,21 @@ class StartPushNotificationExperimentCommandTest { when(account.getDevices()).thenReturn(List.of(device)); assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); - verify(experiment).applyExperimentTreatment(account, device); + + if (dryRun) { + verify(experiment, never()).applyExperimentTreatment(any(), any()); + } else { + verify(experiment).applyExperimentTreatment(account, device); + } + + verify(experiment, never()).applyControlTreatment(account, device); } - @Test - void crawlAccountsExistingSample() throws JsonProcessingException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccountsExistingSample(final boolean dryRun) throws JsonProcessingException { + startPushNotificationExperimentCommand.setDryRun(dryRun); + final Device device = mock(Device.class); when(device.getId()).thenReturn(DEVICE_ID); @@ -126,11 +148,14 @@ class StartPushNotificationExperimentCommandTest { .thenReturn(CompletableFuture.completedFuture(false)); assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); - verify(experiment, never()).applyExperimentTreatment(account, device); + verify(experiment, never()).applyExperimentTreatment(any(), any()); } - @Test - void crawlAccountsSampleRetry() throws JsonProcessingException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccountsSampleRetry(final boolean dryRun) throws JsonProcessingException { + startPushNotificationExperimentCommand.setDryRun(dryRun); + final Device device = mock(Device.class); when(device.getId()).thenReturn(DEVICE_ID); @@ -144,9 +169,18 @@ class StartPushNotificationExperimentCommandTest { .thenReturn(CompletableFuture.completedFuture(true)); assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); - verify(experiment).applyExperimentTreatment(account, device); - verify(pushNotificationExperimentSamples, times(3)) - .recordInitialState(ACCOUNT_IDENTIFIER, DEVICE_ID, EXPERIMENT_NAME, true, "test"); + + if (dryRun) { + verify(experiment, never()).applyExperimentTreatment(any(), any()); + verify(pushNotificationExperimentSamples, never()) + .recordInitialState(any(), anyByte(), any(), anyBoolean(), any()); + } else { + verify(experiment).applyExperimentTreatment(account, device); + verify(pushNotificationExperimentSamples, times(3)) + .recordInitialState(ACCOUNT_IDENTIFIER, DEVICE_ID, EXPERIMENT_NAME, true, "test"); + } + + verify(experiment, never()).applyControlTreatment(account, device); } @Test