diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.java index 500509e5b..581b9855e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.java @@ -68,6 +68,11 @@ public class NotifyIdleDevicesWithoutMessagesPushNotificationExperiment implemen .thenApply(mayHavePersistedMessages -> !mayHavePersistedMessages); } + @Override + public Class getStateClass() { + return DeviceLastSeenState.class; + } + @VisibleForTesting static boolean hasPushToken(final Device device) { // Exclude VOIP tokens since they have their own, distinct delivery mechanism diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java index 2661b8b62..05a8fa43c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java @@ -1,3 +1,4 @@ + package org.whispersystems.textsecuregcm.experiment; import org.whispersystems.textsecuregcm.storage.Account; @@ -32,6 +33,13 @@ public interface PushNotificationExperiment { */ CompletableFuture isDeviceEligible(Account account, Device device); + /** + * Returns the class of the state object stored for this experiment. + * + * @return the class of the state object stored for this experiment + */ + Class getStateClass(); + /** * Generates an experiment specific state "snapshot" of the given device. Experiment results are generally evaluated * by comparing a device's state before a treatment is applied and its state after the treatment is applied. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java index 01c034cda..90d39a8f1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java @@ -1,4 +1,11 @@ package org.whispersystems.textsecuregcm.experiment; -public record PushNotificationExperimentSample(boolean inExperimentGroup, T initialState, T finalState) { +import javax.annotation.Nullable; +import java.util.UUID; + +public record PushNotificationExperimentSample(UUID accountIdentifier, + byte deviceId, + boolean inExperimentGroup, + T initialState, + @Nullable T finalState) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java index bb8f310d5..c6b96036a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java @@ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; import reactor.util.retry.Retry; @@ -23,7 +24,6 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; @@ -130,106 +130,129 @@ public class PushNotificationExperimentSamples { * @param finalState the final state of the object; must be serializable as a JSON text and of the same type as the * previously-stored initial state - * @return a future that completes when the final state has been stored; yields a finished sample if an initial sample - * was found or empty if no initial sample was found for the given account, device, and experiment + * @return A future that completes when the final state has been stored; yields a finished sample if an initial sample + * was found or empty if no initial sample was found for the given account, device, and experiment. The future may + * with a {@link JsonProcessingException} if the initial state could not be read or the final state could not be + * written as a JSON text. * * @param the type of state object for this sample - * - * @throws JsonProcessingException if the given {@code finalState} could not be serialized as a JSON text */ public CompletableFuture> recordFinalState(final UUID accountIdentifier, final byte deviceId, final String experimentName, - final T finalState) throws JsonProcessingException { + final T finalState) { + + CompletableFuture finalStateJsonFuture; + + // Process the final state JSON on the calling thread, but inside a CompletionStage so there's just one "channel" + // for reporting JSON exceptions. The alternative is to `throw JsonProcessingException`, but then callers would have + // to both catch the exception when calling this method and also watch the returned future for the same exception. + try { + finalStateJsonFuture = + CompletableFuture.completedFuture(SystemMapper.jsonMapper().writeValueAsString(finalState)); + } catch (final JsonProcessingException e) { + finalStateJsonFuture = CompletableFuture.failedFuture(e); + } final AttributeValue aciAndDeviceIdAttributeValue = buildSortKey(accountIdentifier, deviceId); - return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder() - .tableName(tableName) - .key(Map.of( - KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName), - ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue)) - // `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We - // want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI - // and device ID. - .conditionExpression("#aciAndDeviceId = :aciAndDeviceId") - .updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)") - .expressionAttributeNames(Map.of( - "#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID, - "#finalState", ATTR_FINAL_STATE)) - .expressionAttributeValues(Map.of( - ":aciAndDeviceId", aciAndDeviceIdAttributeValue, - ":finalState", AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(finalState)))) - .returnValues(ReturnValue.ALL_NEW) - .build()) - .thenApply(updateItemResponse -> { - try { - final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool(); + return finalStateJsonFuture.thenCompose(finalStateJson -> { + return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName), + ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue)) + // `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We + // want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI + // and device ID. + .conditionExpression("#aciAndDeviceId = :aciAndDeviceId") + .updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)") + .expressionAttributeNames(Map.of( + "#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID, + "#finalState", ATTR_FINAL_STATE)) + .expressionAttributeValues(Map.of( + ":aciAndDeviceId", aciAndDeviceIdAttributeValue, + ":finalState", AttributeValue.fromS(finalStateJson))) + .returnValues(ReturnValue.ALL_NEW) + .build()) + .thenApply(updateItemResponse -> { + try { + final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool(); - @SuppressWarnings("unchecked") final T parsedInitialState = - (T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass()); + @SuppressWarnings("unchecked") final T parsedInitialState = + (T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass()); - @SuppressWarnings("unchecked") final T parsedFinalState = - (T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass()); + @SuppressWarnings("unchecked") final T parsedFinalState = + (T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass()); - return new PushNotificationExperimentSample<>(inExperimentGroup, parsedInitialState, parsedFinalState); - } catch (final JsonProcessingException e) { - throw ExceptionUtils.wrap(e); - } - }); + return new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, parsedInitialState, parsedFinalState); + } catch (final JsonProcessingException e) { + throw ExceptionUtils.wrap(e); + } + }); + }); } /** - * Returns a publisher across all samples pending a final state for a given experiment. + * Returns a publisher across all samples for a given experiment. * - * @param experimentName the name of the experiment for which to retrieve samples pending a final state - * - * @return a publisher across all samples pending a final state for a given experiment - */ - public Flux> getDevicesPendingFinalState(final String experimentName) { - return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#experiment = :experiment") - .filterExpression("attribute_not_exists(#finalState)") - .expressionAttributeNames(Map.of( - "#experiment", KEY_EXPERIMENT_NAME, - "#finalState", ATTR_FINAL_STATE)) - .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName))) - .projectionExpression(ATTR_ACI_AND_DEVICE_ID) - .build()) - .items()) - .map(item -> parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID))); - } - - /** - * Returns a publisher across all finished samples (i.e. samples with a recorded final state) for a given experiment. - * - * @param experimentName the name of the experiment for which to retrieve finished samples + * @param experimentName the name of the experiment for which to fetch samples * @param stateClass the type of state object for sample in the given experiment + * @param totalSegments the number of segments into which the scan of the backing data store will be divided * - * @return a publisher across all finished samples for the given experiment + * @return a publisher of tuples of ACI, device ID, and sample for all samples associated with the given experiment * * @param the type of the sample's state objects + * + * @see Working with scans - Parallel scan */ - public Flux> getFinishedSamples(final String experimentName, - final Class stateClass) { - return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#experiment = :experiment") - .filterExpression("attribute_exists(#finalState)") - .expressionAttributeNames(Map.of( - "#experiment", KEY_EXPERIMENT_NAME, - "#finalState", ATTR_FINAL_STATE)) + public Flux> getSamples(final String experimentName, + final Class stateClass, + final int totalSegments, + final Scheduler scheduler) { + + // Note that we're using a DynamoDB Scan operation instead of a Query. A Query would allow us to limit the search + // space to a specific experiment, but doesn't allow us to use segments. A Scan will always inspect all items in the + // table, but allows us to segment the search. Since we're generally calling this method in conjunction with "…and + // record a final state for the sample," distributing reads/writes across shards helps us avoid per-partition + // throughput limits. If we wind up with many concurrent experiments, it may be worthwhile to revisit this decision. + + if (totalSegments < 1) { + throw new IllegalArgumentException("Total number of segments must be positive"); + } + + return Flux.range(0, totalSegments) + .parallel() + .runOn(scheduler) + .flatMap(segment -> getSamplesFromSegment(experimentName, stateClass, segment, totalSegments)) + .sequential(); + } + + private Flux> getSamplesFromSegment(final String experimentName, + final Class stateClass, + final int segment, + final int totalSegments) { + + return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder() + .tableName(tableName) + .segment(segment) + .totalSegments(totalSegments) + .filterExpression("#experiment = :experiment") + .expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME)) .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName))) - .build()) - .items()) + .build()) + .items()) .handle((item, sink) -> { try { + final Tuple2 aciAndDeviceId = parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID)); + final boolean inExperimentGroup = item.get(ATTR_IN_EXPERIMENT_GROUP).bool(); final T initialState = parseState(item.get(ATTR_INITIAL_STATE).s(), stateClass); - final T finalState = parseState(item.get(ATTR_FINAL_STATE).s(), stateClass); + final T finalState = item.get(ATTR_FINAL_STATE) != null + ? parseState(item.get(ATTR_FINAL_STATE).s(), stateClass) + : null; - sink.next(new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState)); + sink.next(new PushNotificationExperimentSample<>(aciAndDeviceId.getT1(), aciAndDeviceId.getT2(), inExperimentGroup, initialState, finalState)); } catch (final JsonProcessingException e) { sink.error(e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java index 193103c1d..d26950aaa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java @@ -4,6 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.core.Application; import io.dropwizard.core.setup.Environment; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; @@ -12,17 +14,14 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; -import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.util.function.Tuples; +import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; -import javax.annotation.Nullable; import java.time.Duration; -import java.util.UUID; public class FinishPushNotificationExperimentCommand extends AbstractCommandWithDependencies { @@ -33,6 +32,18 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW @VisibleForTesting static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + @VisibleForTesting + static final String SEGMENT_COUNT_ARGUMENT = "segments"; + + private static final String SAMPLES_READ_COUNTER_NAME = + MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "samplesRead"); + + private static final Counter ACCOUNT_READ_COUNTER = + Metrics.counter(MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "accountRead")); + + private static final Counter FINAL_SAMPLE_STORED_COUNTER = + Metrics.counter(MetricsUtil.name(FinishPushNotificationExperimentCommand.class, "finalSampleStored")); + private static final Logger log = LoggerFactory.getLogger(FinishPushNotificationExperimentCommand.class); public FinishPushNotificationExperimentCommand(final String name, @@ -57,6 +68,13 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW .dest(MAX_CONCURRENCY_ARGUMENT) .setDefault(DEFAULT_MAX_CONCURRENCY) .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--segments") + .type(Integer.class) + .dest(SEGMENT_COUNT_ARGUMENT) + .required(false) + .setDefault(16) + .help("The total number of segments for a DynamoDB scan"); } @Override @@ -69,6 +87,7 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW experimentFactory.buildExperiment(commandDependencies, configuration); final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT); + final int segments = namespace.getInt(SEGMENT_COUNT_ARGUMENT); log.info("Finishing \"{}\" with max concurrency: {}", experiment.getExperimentName(), maxConcurrency); @@ -76,48 +95,44 @@ public class FinishPushNotificationExperimentCommand extends AbstractCommandW final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples(); final Flux> finishedSamples = - pushNotificationExperimentSamples.getDevicesPendingFinalState(experiment.getExperimentName()) - .flatMap(accountIdentifierAndDeviceId -> - Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(accountIdentifierAndDeviceId.getT1())) + pushNotificationExperimentSamples.getSamples(experiment.getExperimentName(), + experiment.getStateClass(), + segments, + Schedulers.parallel()) + .doOnNext(sample -> Metrics.counter(SAMPLES_READ_COUNTER_NAME, "final", String.valueOf(sample.finalState() != null)).increment()) + .flatMap(sample -> { + if (sample.finalState() == null) { + // We still need to record a final state for this sample + return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(sample.accountIdentifier())) .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) - .map(maybeAccount -> Tuples.of(accountIdentifierAndDeviceId.getT1(), - accountIdentifierAndDeviceId.getT2(), maybeAccount)), maxConcurrency) - .map(accountIdentifierAndDeviceIdAndMaybeAccount -> { - final UUID accountIdentifier = accountIdentifierAndDeviceIdAndMaybeAccount.getT1(); - final byte deviceId = accountIdentifierAndDeviceIdAndMaybeAccount.getT2(); + .doOnNext(ignored -> ACCOUNT_READ_COUNTER.increment()) + .flatMap(maybeAccount -> { + final T finalState = experiment.getState(maybeAccount.orElse(null), + maybeAccount.flatMap(account -> account.getDevice(sample.deviceId())).orElse(null)); - @Nullable final Account account = accountIdentifierAndDeviceIdAndMaybeAccount.getT3() - .orElse(null); + return Mono.fromFuture( + () -> pushNotificationExperimentSamples.recordFinalState(sample.accountIdentifier(), + sample.deviceId(), + experiment.getExperimentName(), + finalState)) + .onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty()) + .onErrorResume(JsonProcessingException.class, throwable -> { + log.error("Failed to parse sample state JSON", throwable); + return Mono.empty(); + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) + .onErrorResume(throwable -> { + log.warn("Failed to record final state for {}:{} in experiment {}", + sample.accountIdentifier(), sample.deviceId(), experiment.getExperimentName(), throwable); - @Nullable final Device device = accountIdentifierAndDeviceIdAndMaybeAccount.getT3() - .flatMap(a -> a.getDevice(deviceId)) - .orElse(null); - - return Tuples.of(accountIdentifier, deviceId, experiment.getState(account, device)); - }) - .flatMap(accountIdentifierAndDeviceIdAndFinalState -> { - final UUID accountIdentifier = accountIdentifierAndDeviceIdAndFinalState.getT1(); - final byte deviceId = accountIdentifierAndDeviceIdAndFinalState.getT2(); - final T finalState = accountIdentifierAndDeviceIdAndFinalState.getT3(); - - return Mono.fromFuture(() -> { - try { - return pushNotificationExperimentSamples.recordFinalState(accountIdentifier, deviceId, - experiment.getExperimentName(), finalState); - } catch (final JsonProcessingException e) { - throw new RuntimeException(e); - } - }) - .onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty()) - .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) - .onErrorResume(throwable -> { - log.warn("Failed to record final state for {}:{} in experiment {}", - accountIdentifier, deviceId, experiment.getExperimentName(), throwable); - - return Mono.empty(); - }); - }, maxConcurrency) - .flatMap(Mono::justOrEmpty); + return Mono.empty(); + }) + .doOnSuccess(ignored -> FINAL_SAMPLE_STORED_COUNTER.increment()); + }); + } else { + return Mono.just(sample); + } + }, maxConcurrency); experiment.analyzeResults(finishedSamples); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest.java index 9bb1a6531..67641a37d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest.java @@ -206,7 +206,7 @@ class NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest { final DeviceLastSeenState state = new DeviceLastSeenState(true, 0, true, 0, tokenType); final PushNotificationExperimentSample sample = - new PushNotificationExperimentSample<>(inExperimentGroup, state, state); + new PushNotificationExperimentSample<>(UUID.randomUUID(), Device.PRIMARY_ID, inExperimentGroup, state, state); assertEquals(expectedPopulation, NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.getPopulation(sample)); } @@ -234,7 +234,7 @@ class NotifyIdleDevicesWithoutMessagesPushNotificationExperimentTest { final NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.Outcome expectedOutcome) { final PushNotificationExperimentSample sample = - new PushNotificationExperimentSample<>(true, initialState, finalState); + new PushNotificationExperimentSample<>(UUID.randomUUID(), Device.PRIMARY_ID, true, initialState, finalState); assertEquals(expectedOutcome, NotifyIdleDevicesWithoutMessagesPushNotificationExperiment.getOutcome(sample)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java index 004167f06..ab8e4818b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java @@ -9,9 +9,11 @@ import com.fasterxml.jackson.core.JsonProcessingException; import java.time.Clock; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletionException; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -19,7 +21,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; import org.whispersystems.textsecuregcm.util.SystemMapper; -import reactor.util.function.Tuples; +import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; @@ -63,7 +65,7 @@ class PushNotificationExperimentSamplesTest { .join(), "Attempt to record an initial state should succeed for entirely new records"); - assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + assertEquals(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(bounciness), null), getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class)); assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, @@ -74,7 +76,7 @@ class PushNotificationExperimentSamplesTest { .join(), "Attempt to re-record an initial state should succeed for existing-but-unchanged records"); - assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + assertEquals(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(bounciness), null), getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), "Recorded initial state should be unchanged after repeated write"); @@ -86,7 +88,7 @@ class PushNotificationExperimentSamplesTest { .join(), "Attempt to record a conflicting initial state should fail"); - assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + assertEquals(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(bounciness), null), getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), "Recorded initial state should be unchanged after unsuccessful write"); @@ -98,7 +100,7 @@ class PushNotificationExperimentSamplesTest { .join(), "Attempt to record a new group assignment should fail"); - assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + assertEquals(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(bounciness), null), getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), "Recorded initial state should be unchanged after unsuccessful write"); @@ -118,7 +120,7 @@ class PushNotificationExperimentSamplesTest { .join(), "Attempt to record an initial state should fail for samples with final states"); - assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), new TestDeviceState(finalBounciness)), + assertEquals(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(bounciness), new TestDeviceState(finalBounciness)), getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), "Recorded initial state should be unchanged after unsuccessful write"); } @@ -148,7 +150,7 @@ class PushNotificationExperimentSamplesTest { .join(); final PushNotificationExperimentSample expectedSample = - new PushNotificationExperimentSample<>(inExperimentGroup, + new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, new TestDeviceState(initialBounciness), new TestDeviceState(finalBounciness)); @@ -190,92 +192,65 @@ class PushNotificationExperimentSamplesTest { ? SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_FINAL_STATE).s(), stateClass) : null; - return new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState); + return new PushNotificationExperimentSample<>(accountIdentifier, deviceId, inExperimentGroup, initialState, finalState); } @Test - void getDevicesPendingFinalState() throws JsonProcessingException { + void getSamples() throws JsonProcessingException { final String experimentName = "test-experiment"; - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); - final int initialBounciness = ThreadLocalRandom.current().nextInt(); + final UUID initialSampleAccountIdentifier = UUID.randomUUID(); + final byte initialSampleDeviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean initialSampleInExperimentGroup = ThreadLocalRandom.current().nextBoolean(); - //noinspection DataFlowIssue - assertTrue(pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block().isEmpty()); + final UUID finalSampleAccountIdentifier = UUID.randomUUID(); + final byte finalSampleDeviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean finalSampleInExperimentGroup = ThreadLocalRandom.current().nextBoolean(); - pushNotificationExperimentSamples.recordInitialState(accountIdentifier, - deviceId, - experimentName, - inExperimentGroup, - new TestDeviceState(initialBounciness)) - .join(); - - pushNotificationExperimentSamples.recordInitialState(accountIdentifier, - (byte) (deviceId + 1), - experimentName + "-different", - inExperimentGroup, - new TestDeviceState(initialBounciness)) - .join(); - - assertEquals(List.of(Tuples.of(accountIdentifier, deviceId)), - pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block()); - } - - @Test - void getFinishedSamples() throws JsonProcessingException { - final String experimentName = "test-experiment"; - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); final int initialBounciness = ThreadLocalRandom.current().nextInt(); final int finalBounciness = initialBounciness + 17; - //noinspection DataFlowIssue - assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty()); - - pushNotificationExperimentSamples.recordInitialState(accountIdentifier, - deviceId, + pushNotificationExperimentSamples.recordInitialState(initialSampleAccountIdentifier, + initialSampleDeviceId, experimentName, - inExperimentGroup, + initialSampleInExperimentGroup, new TestDeviceState(initialBounciness)) .join(); - //noinspection DataFlowIssue - assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty(), - "Publisher should not return unfinished samples"); - - pushNotificationExperimentSamples.recordFinalState(accountIdentifier, - deviceId, - experimentName, - new TestDeviceState(finalBounciness)) - .join(); - - final List> expectedSamples = - List.of(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(initialBounciness), new TestDeviceState(finalBounciness))); - - assertEquals( - expectedSamples, - pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(), - "Publisher should return finished samples"); - - pushNotificationExperimentSamples.recordInitialState(accountIdentifier, - deviceId, - experimentName + "-different", - inExperimentGroup, + pushNotificationExperimentSamples.recordInitialState(finalSampleAccountIdentifier, + finalSampleDeviceId, + experimentName, + finalSampleInExperimentGroup, new TestDeviceState(initialBounciness)) .join(); - pushNotificationExperimentSamples.recordFinalState(accountIdentifier, - deviceId, - experimentName + "-different", + pushNotificationExperimentSamples.recordFinalState(finalSampleAccountIdentifier, + finalSampleDeviceId, + experimentName, new TestDeviceState(finalBounciness)) .join(); - assertEquals( - expectedSamples, - pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(), - "Publisher should return finished samples only from named experiment"); + pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(), + (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID), + experimentName + "-different", + ThreadLocalRandom.current().nextBoolean(), + new TestDeviceState(ThreadLocalRandom.current().nextInt())) + .join(); + + final Set> expectedSamples = Set.of( + new PushNotificationExperimentSample<>(initialSampleAccountIdentifier, + initialSampleDeviceId, + initialSampleInExperimentGroup, + new TestDeviceState(initialBounciness), + null), + + new PushNotificationExperimentSample<>(finalSampleAccountIdentifier, + finalSampleDeviceId, + finalSampleInExperimentGroup, + new TestDeviceState(initialBounciness), + new TestDeviceState(finalBounciness))); + + assertEquals(expectedSamples, + pushNotificationExperimentSamples.getSamples(experimentName, TestDeviceState.class, 1, Schedulers.immediate()).collect(Collectors.toSet()).block()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java index 795b49547..7e30f1d19 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java @@ -1,10 +1,24 @@ package org.whispersystems.textsecuregcm.workers; -import com.fasterxml.jackson.core.JsonProcessingException; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; import net.sourceforge.argparse4j.inf.Namespace; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; @@ -12,24 +26,8 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import reactor.core.publisher.Flux; -import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - class FinishPushNotificationExperimentCommandTest { private CommandDependencies commandDependencies; @@ -39,6 +37,10 @@ class FinishPushNotificationExperimentCommandTest { private static final String EXPERIMENT_NAME = "test"; + private static final Namespace NAMESPACE = new Namespace(Map.of( + FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1, + FinishPushNotificationExperimentCommand.SEGMENT_COUNT_ARGUMENT, 1)); + private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand { public TestFinishPushNotificationExperimentCommand(final PushNotificationExperiment experiment) { @@ -49,14 +51,20 @@ class FinishPushNotificationExperimentCommandTest { } @BeforeEach - void setUp() throws JsonProcessingException { + void setUp() { final AccountsManager accountsManager = mock(AccountsManager.class); final PushNotificationExperimentSamples pushNotificationExperimentSamples = mock(PushNotificationExperimentSamples.class); when(pushNotificationExperimentSamples.recordFinalState(any(), anyByte(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test"))); + .thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(0); + final byte deviceId = invocation.getArgument(1); + + return CompletableFuture.completedFuture( + new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", "test")); + }); commandDependencies = new CommandDependencies(accountsManager, null, @@ -78,6 +86,7 @@ class FinishPushNotificationExperimentCommandTest { experiment = mock(PushNotificationExperiment.class); when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME); when(experiment.getState(any(), any())).thenReturn("test"); + when(experiment.getStateClass()).thenReturn(String.class); doAnswer(invocation -> { final Flux> samples = invocation.getArgument(0); @@ -90,7 +99,7 @@ class FinishPushNotificationExperimentCommandTest { } @Test - void run() throws Exception { + void run() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; @@ -103,69 +112,54 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); - - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); verify(experiment).getState(account, device); - verify(commandDependencies.pushNotificationExperimentSamples()) .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } @Test - void runMissingAccount() throws Exception { + void runMissingAccount() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); - - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); verify(experiment).getState(null, null); - verify(commandDependencies.pushNotificationExperimentSamples()) .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } @Test - void runMissingDevice() throws Exception { + void runMissingDevice() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; final Account account = mock(Account.class); - when(account.getDevice(anyByte())).thenReturn(Optional.empty()); + when(account.getDevice(deviceId)).thenReturn(Optional.empty()); when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); - - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); verify(experiment).getState(account, null); - verify(commandDependencies.pushNotificationExperimentSamples()) .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } @Test - void runAccountFetchRetry() throws Exception { + void runAccountFetchRetry() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; @@ -180,24 +174,19 @@ class FinishPushNotificationExperimentCommandTest { .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); + verify(experiment).getState(account, device); + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); verify(commandDependencies.accountsManager(), times(3)).getByAccountIdentifierAsync(accountIdentifier); - - verify(experiment).getState(account, device); - - verify(commandDependencies.pushNotificationExperimentSamples()) - .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } @Test - void runStoreSampleRetry() throws Exception { + void runStoreSampleRetry() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; @@ -210,27 +199,22 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) - .thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test"))); - - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); + .thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", "test"))); + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); verify(experiment).getState(account, device); - verify(commandDependencies.pushNotificationExperimentSamples(), times(3)) .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } @Test - void runMissingInitialSample() throws Exception { + void runMissingInitialSample() { final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; @@ -243,20 +227,27 @@ class FinishPushNotificationExperimentCommandTest { when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) - .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(accountIdentifier, deviceId, true, "test", null))); when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) .thenReturn(CompletableFuture.failedFuture(ConditionalCheckFailedException.builder().build())); - assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, - new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), - null, - commandDependencies)); - + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); verify(experiment).getState(account, device); - verify(commandDependencies.pushNotificationExperimentSamples()) .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); } + + @Test + void runFinalSampleAlreadyRecorded() { + when(commandDependencies.pushNotificationExperimentSamples().getSamples(eq(EXPERIMENT_NAME), eq(String.class), anyInt(), any())) + .thenReturn(Flux.just(new PushNotificationExperimentSample<>(UUID.randomUUID(), Device.PRIMARY_ID, true, "test", "test"))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, NAMESPACE, null, commandDependencies)); + verify(commandDependencies.accountsManager(), never()).getByAccountIdentifier(any()); + verify(experiment, never()).getState(any(), any()); + verify(commandDependencies.pushNotificationExperimentSamples(), never()) + .recordFinalState(any(), anyByte(), any(), any()); + } }