From 4ebad2c473c116e61991e3fc01589505d194f4a2 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:36:05 -0400 Subject: [PATCH] Add a framework for running experiments to improve push notification reliability --- service/config/sample.yml | 2 + .../configuration/DynamoDbTables.java | 9 + .../PushNotificationExperiment.java | 68 ++++ .../PushNotificationExperimentSample.java | 4 + .../PushNotificationExperimentSamples.java | 284 +++++++++++++++ .../workers/CommandDependencies.java | 36 +- ...hNotificationExperimentSamplesCommand.java | 62 ++++ ...nishPushNotificationExperimentCommand.java | 117 +++++++ .../PushNotificationExperimentFactory.java | 10 + ...nPushNotificationSenderServiceCommand.java | 6 +- ...tartPushNotificationExperimentCommand.java | 133 +++++++ ...PushNotificationExperimentSamplesTest.java | 331 ++++++++++++++++++ .../storage/DynamoDbExtensionSchema.java | 15 + ...PushNotificationExperimentCommandTest.java | 252 +++++++++++++ ...PushNotificationExperimentCommandTest.java | 166 +++++++++ service/src/test/resources/config/test.yml | 2 + 16 files changed, 1489 insertions(+), 8 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/DiscardPushNotificationExperimentSamplesCommand.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/PushNotificationExperimentFactory.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java diff --git a/service/config/sample.yml b/service/config/sample.yml index ab500ac2d..4800d9caa 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -116,6 +116,8 @@ dynamoDbTables: tableName: Example_Profiles pushChallenge: tableName: Example_PushChallenge + pushNotificationExperimentSamples: + tableName: Example_PushNotificationExperimentSamples redeemedReceipts: tableName: Example_RedeemedReceipts expiration: P30D # Duration of time until rows expire diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java index bb0198ce2..0d221c638 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java @@ -63,6 +63,7 @@ public class DynamoDbTables { private final Table phoneNumberIdentifiers; private final Table profiles; private final Table pushChallenge; + private final Table pushNotificationExperimentSamples; private final TableWithExpiration redeemedReceipts; private final TableWithExpiration registrationRecovery; private final Table remoteConfig; @@ -88,6 +89,7 @@ public class DynamoDbTables { @JsonProperty("phoneNumberIdentifiers") final Table phoneNumberIdentifiers, @JsonProperty("profiles") final Table profiles, @JsonProperty("pushChallenge") final Table pushChallenge, + @JsonProperty("pushNotificationExperimentSamples") final Table pushNotificationExperimentSamples, @JsonProperty("redeemedReceipts") final TableWithExpiration redeemedReceipts, @JsonProperty("registrationRecovery") final TableWithExpiration registrationRecovery, @JsonProperty("remoteConfig") final Table remoteConfig, @@ -112,6 +114,7 @@ public class DynamoDbTables { this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.profiles = profiles; this.pushChallenge = pushChallenge; + this.pushNotificationExperimentSamples = pushNotificationExperimentSamples; this.redeemedReceipts = redeemedReceipts; this.registrationRecovery = registrationRecovery; this.remoteConfig = remoteConfig; @@ -217,6 +220,12 @@ public class DynamoDbTables { return pushChallenge; } + @NotNull + @Valid + public Table getPushNotificationExperimentSamples() { + return pushNotificationExperimentSamples; + } + @NotNull @Valid public TableWithExpiration getRedeemedReceipts() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java new file mode 100644 index 000000000..92ec44f3c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperiment.java @@ -0,0 +1,68 @@ +package org.whispersystems.textsecuregcm.experiment; + +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * A push notification selects for eligible devices, applies a control or experimental treatment, and provides a + * mechanism for comparing device states before and after receiving the treatment. + * + * @param the type of state object stored for this experiment + */ +public interface PushNotificationExperiment { + + /** + * Returns the unique name of this experiment. + * + * @return the unique name of this experiment + */ + String getExperimentName(); + + /** + * Tests whether a device is eligible for this experiment. An eligible device may be assigned to either the control + * or experiment group within an experiment. Ineligible devices will not participate in the experiment in any way. + * + * @param account the account to which the device belongs + * @param device the device to test for eligibility in this experiment + * + * @return a future that yields a boolean value indicating whether the target device is eligible for this experiment + */ + CompletableFuture isDeviceEligible(Account account, Device device); + + /** + * Generates an experiment specific state "snapshot" of the given device. Experiment results are generally evaluated + * by comparing a device's state before a treatment is applied and its state after the treatment is applied. + * + * @param account the account to which the device belongs + * @param device the device for which to generate a state "snapshot" + * + * @return an experiment-specific state "snapshot" of the given device + */ + T getState(@Nullable Account account, @Nullable Device device); + + /** + * Applies a control treatment to the given device. In many cases (and by default) no action is taken for devices in + * the control group. + * + * @param account the account to which the device belongs + * @param device the device to which to apply the control treatment for this experiment + * + * @return a future that completes when the control treatment has been applied for the given device + */ + default CompletableFuture applyControlTreatment(Account account, Device device) { + return CompletableFuture.completedFuture(null); + }; + + /** + * Applies an experimental treatment to the given device. This generally involves sending or scheduling a specific + * type of push notification for the given device. + * + * @param account the account to which the device belongs + * @param device the device to which to apply the experimental treatment for this experiment + * + * @return a future that completes when the experimental treatment has been applied for the given device + */ + CompletableFuture applyExperimentTreatment(Account account, Device device); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java new file mode 100644 index 000000000..01c034cda --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSample.java @@ -0,0 +1,4 @@ +package org.whispersystems.textsecuregcm.experiment; + +public record PushNotificationExperimentSample(boolean inExperimentGroup, T initialState, T finalState) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java new file mode 100644 index 000000000..bb8f310d5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamples.java @@ -0,0 +1,284 @@ +package org.whispersystems.textsecuregcm.experiment; + +import com.fasterxml.jackson.core.JsonProcessingException; +import java.nio.ByteBuffer; +import java.time.Clock; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.ReturnValue; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; + +public class PushNotificationExperimentSamples { + + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final String tableName; + private final Clock clock; + + // Experiment name; DynamoDB string; partition key + public static final String KEY_EXPERIMENT_NAME = "N"; + + // Combined ACI and device ID; DynamoDB byte array; sort key + public static final String ATTR_ACI_AND_DEVICE_ID = "AD"; + + // Whether the device is enrolled in the experiment group (as opposed to control group); DynamoDB boolean + static final String ATTR_IN_EXPERIMENT_GROUP = "X"; + + // The experiment-specific state of the device at the start of the experiment, represented as a JSON blob; DynamoDB + // string + static final String ATTR_INITIAL_STATE = "I"; + + // The experiment-specific state of the device at the end of the experiment, represented as a JSON blob; DynamoDB + // string + static final String ATTR_FINAL_STATE = "F"; + + // The time, in seconds since the epoch, at which this sample should be deleted automatically + static final String ATTR_TTL = "E"; + + private static final Duration FINAL_SAMPLE_TTL = Duration.ofDays(7); + + private static final Logger log = LoggerFactory.getLogger(PushNotificationExperimentSamples.class); + + public PushNotificationExperimentSamples(final DynamoDbAsyncClient dynamoDbAsyncClient, + final String tableName, + final Clock clock) { + + this.dynamoDbAsyncClient = dynamoDbAsyncClient; + this.tableName = tableName; + this.clock = clock; + } + + /** + * Writes the initial state of a device participating in a push notification experiment. + * + * @param accountIdentifier the account identifier for the account to which the target device is linked + * @param deviceId the identifier for the device within the given account + * @param experimentName the name of the experiment + * @param inExperimentGroup whether the given device is in the experiment group (as opposed to control group) + * @param initialState the initial state of the object; must be serializable as a JSON text + * + * @return a future that completes when the record has been stored; the future yields {@code true} if a new record + * was stored or {@code false} if a conflicting record already exists + * + * @param the type of state object for this sample + * + * @throws JsonProcessingException if the given {@code initialState} could not be serialized as a JSON text + */ + public CompletableFuture recordInitialState(final UUID accountIdentifier, + final byte deviceId, + final String experimentName, + final boolean inExperimentGroup, + final T initialState) throws JsonProcessingException { + + final AttributeValue initialStateAttributeValue = + AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(initialState)); + + final AttributeValue inExperimentGroupAttributeValue = AttributeValue.fromBool(inExperimentGroup); + + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() + .tableName(tableName) + .item(Map.of( + KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName), + ATTR_ACI_AND_DEVICE_ID, buildSortKey(accountIdentifier, deviceId), + ATTR_IN_EXPERIMENT_GROUP, inExperimentGroupAttributeValue, + ATTR_INITIAL_STATE, initialStateAttributeValue, + ATTR_TTL, AttributeValue.fromN(String.valueOf(clock.instant().plus(FINAL_SAMPLE_TTL).getEpochSecond())))) + .conditionExpression("(attribute_not_exists(#inExperimentGroup) OR #inExperimentGroup = :inExperimentGroup) AND (attribute_not_exists(#initialState) OR #initialState = :initialState) AND attribute_not_exists(#finalState)") + .expressionAttributeNames(Map.of( + "#inExperimentGroup", ATTR_IN_EXPERIMENT_GROUP, + "#initialState", ATTR_INITIAL_STATE, + "#finalState", ATTR_FINAL_STATE)) + .expressionAttributeValues(Map.of( + ":inExperimentGroup", inExperimentGroupAttributeValue, + ":initialState", initialStateAttributeValue)) + .build()) + .thenApply(ignored -> true) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) { + return false; + } + + throw ExceptionUtils.wrap(throwable); + }); + } + + /** + * Writes the final state of a device participating in a push notification experiment. + * + * @param accountIdentifier the account identifier for the account to which the target device is linked + * @param deviceId the identifier for the device within the given account + * @param experimentName the name of the experiment + * @param finalState the final state of the object; must be serializable as a JSON text and of the same type as the + * previously-stored initial state + + * @return a future that completes when the final state has been stored; yields a finished sample if an initial sample + * was found or empty if no initial sample was found for the given account, device, and experiment + * + * @param the type of state object for this sample + * + * @throws JsonProcessingException if the given {@code finalState} could not be serialized as a JSON text + */ + public CompletableFuture> recordFinalState(final UUID accountIdentifier, + final byte deviceId, + final String experimentName, + final T finalState) throws JsonProcessingException { + + final AttributeValue aciAndDeviceIdAttributeValue = buildSortKey(accountIdentifier, deviceId); + + return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName), + ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue)) + // `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We + // want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI + // and device ID. + .conditionExpression("#aciAndDeviceId = :aciAndDeviceId") + .updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)") + .expressionAttributeNames(Map.of( + "#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID, + "#finalState", ATTR_FINAL_STATE)) + .expressionAttributeValues(Map.of( + ":aciAndDeviceId", aciAndDeviceIdAttributeValue, + ":finalState", AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(finalState)))) + .returnValues(ReturnValue.ALL_NEW) + .build()) + .thenApply(updateItemResponse -> { + try { + final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool(); + + @SuppressWarnings("unchecked") final T parsedInitialState = + (T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass()); + + @SuppressWarnings("unchecked") final T parsedFinalState = + (T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass()); + + return new PushNotificationExperimentSample<>(inExperimentGroup, parsedInitialState, parsedFinalState); + } catch (final JsonProcessingException e) { + throw ExceptionUtils.wrap(e); + } + }); + } + + /** + * Returns a publisher across all samples pending a final state for a given experiment. + * + * @param experimentName the name of the experiment for which to retrieve samples pending a final state + * + * @return a publisher across all samples pending a final state for a given experiment + */ + public Flux> getDevicesPendingFinalState(final String experimentName) { + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#experiment = :experiment") + .filterExpression("attribute_not_exists(#finalState)") + .expressionAttributeNames(Map.of( + "#experiment", KEY_EXPERIMENT_NAME, + "#finalState", ATTR_FINAL_STATE)) + .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName))) + .projectionExpression(ATTR_ACI_AND_DEVICE_ID) + .build()) + .items()) + .map(item -> parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID))); + } + + /** + * Returns a publisher across all finished samples (i.e. samples with a recorded final state) for a given experiment. + * + * @param experimentName the name of the experiment for which to retrieve finished samples + * @param stateClass the type of state object for sample in the given experiment + * + * @return a publisher across all finished samples for the given experiment + * + * @param the type of the sample's state objects + */ + public Flux> getFinishedSamples(final String experimentName, + final Class stateClass) { + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#experiment = :experiment") + .filterExpression("attribute_exists(#finalState)") + .expressionAttributeNames(Map.of( + "#experiment", KEY_EXPERIMENT_NAME, + "#finalState", ATTR_FINAL_STATE)) + .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName))) + .build()) + .items()) + .handle((item, sink) -> { + try { + final boolean inExperimentGroup = item.get(ATTR_IN_EXPERIMENT_GROUP).bool(); + final T initialState = parseState(item.get(ATTR_INITIAL_STATE).s(), stateClass); + final T finalState = parseState(item.get(ATTR_FINAL_STATE).s(), stateClass); + + sink.next(new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState)); + } catch (final JsonProcessingException e) { + sink.error(e); + } + }); + } + + public CompletableFuture discardSamples(final String experimentName, final int maxConcurrency) { + final AttributeValue experimentNameAttributeValue = AttributeValue.fromS(experimentName); + + return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder() + .tableName(tableName) + .filterExpression("#experiment = :experiment") + .expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME)) + .expressionAttributeValues(Map.of(":experiment", experimentNameAttributeValue)) + .projectionExpression(ATTR_ACI_AND_DEVICE_ID) + .build()) + .items()) + .map(item -> item.get(ATTR_ACI_AND_DEVICE_ID)) + .flatMap(aciAndDeviceId -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_EXPERIMENT_NAME, experimentNameAttributeValue, + ATTR_ACI_AND_DEVICE_ID, aciAndDeviceId)) + .build())) + .retryWhen(Retry.backoff(5, Duration.ofSeconds(5))) + .onErrorResume(throwable -> { + log.warn("Failed to delete sample for experiment {}", experimentName, throwable); + return Mono.empty(); + }), maxConcurrency) + .then() + .toFuture(); + } + + @VisibleForTesting + static AttributeValue buildSortKey(final UUID accountIdentifier, final byte deviceId) { + return AttributeValue.fromB(SdkBytes.fromByteBuffer(ByteBuffer.allocate(17) + .putLong(accountIdentifier.getMostSignificantBits()) + .putLong(accountIdentifier.getLeastSignificantBits()) + .put(deviceId) + .flip())); + } + + private static Tuple2 parseSortKey(final AttributeValue sortKey) { + final ByteBuffer byteBuffer = sortKey.b().asByteBuffer(); + + return Tuples.of(new UUID(byteBuffer.getLong(), byteBuffer.getLong()), byteBuffer.get()); + } + + private static T parseState(final String state, final Class clazz) throws JsonProcessingException { + return SystemMapper.jsonMapper().readValue(state, clazz); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index f12445273..cd7e4589b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -11,6 +11,8 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import io.dropwizard.core.setup.Environment; import io.lettuce.core.resource.ClientResources; import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.time.Clock; import java.util.concurrent.ExecutorService; @@ -28,8 +30,13 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.APNSender; +import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.FcmSender; +import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -68,7 +75,10 @@ record CommandDependencies( MessagesManager messagesManager, ClientPresenceManager clientPresenceManager, KeysManager keysManager, + PushNotificationManager pushNotificationManager, + PushNotificationExperimentSamples pushNotificationExperimentSamples, FaultTolerantRedisCluster cacheCluster, + FaultTolerantRedisCluster pushSchedulerCluster, ClientResources.Builder redisClusterClientResourcesBuilder, BackupManager backupManager, DynamicConfigurationManager dynamicConfigurationManager) { @@ -76,7 +86,8 @@ record CommandDependencies( static CommandDependencies build( final String name, final Environment environment, - final WhisperServerConfiguration configuration) throws IOException, CertificateException { + final WhisperServerConfiguration configuration) + throws IOException, CertificateException, NoSuchAlgorithmException, InvalidKeyException { Clock clock = Clock.systemUTC(); environment.getObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @@ -92,8 +103,10 @@ record CommandDependencies( final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder(); - FaultTolerantRedisCluster cacheCluster = configuration.getCacheClusterConfiguration().build("main_cache", - redisClientResourcesBuilder); + FaultTolerantRedisCluster cacheCluster = configuration.getCacheClusterConfiguration() + .build("main_cache", redisClientResourcesBuilder); + FaultTolerantRedisCluster pushSchedulerCluster = configuration.getPushSchedulerCluster() + .build("push_scheduler", redisClientResourcesBuilder); ScheduledExecutorService recurringJobExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build(); @@ -115,6 +128,10 @@ record CommandDependencies( .executorService(name(name, "remoteStorage-%d")) .minThreads(0).maxThreads(Integer.MAX_VALUE).workQueue(new SynchronousQueue<>()) .keepAliveTime(io.dropwizard.util.Duration.seconds(60L)).build(); + ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(name, "apnSender-%d")) + .maxThreads(1).minThreads(1).build(); + ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d")) + .maxThreads(16).minThreads(16).build(); ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build(); @@ -225,6 +242,16 @@ record CommandDependencies( remoteStorageRetryExecutor, configuration.getCdn3StorageManagerConfiguration()), clock); + APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); + FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value()); + ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, + apnSender, accountsManager, 0); + PushNotificationManager pushNotificationManager = + new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); + PushNotificationExperimentSamples pushNotificationExperimentSamples = + new PushNotificationExperimentSamples(dynamoDbAsyncClient, + configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(), + Clock.systemUTC()); environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(clientPresenceManager); @@ -238,7 +265,10 @@ record CommandDependencies( messagesManager, clientPresenceManager, keys, + pushNotificationManager, + pushNotificationExperimentSamples, cacheCluster, + pushSchedulerCluster, redisClientResourcesBuilder, backupManager, dynamicConfigurationManager diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DiscardPushNotificationExperimentSamplesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DiscardPushNotificationExperimentSamplesCommand.java new file mode 100644 index 000000000..497208605 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DiscardPushNotificationExperimentSamplesCommand.java @@ -0,0 +1,62 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.core.Application; +import io.dropwizard.core.setup.Environment; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; + +public class DiscardPushNotificationExperimentSamplesCommand extends AbstractCommandWithDependencies { + + private final PushNotificationExperimentFactory experimentFactory; + + private static final int DEFAULT_MAX_CONCURRENCY = 16; + + @VisibleForTesting + static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + private static final Logger log = LoggerFactory.getLogger(DiscardPushNotificationExperimentSamplesCommand.class); + + public DiscardPushNotificationExperimentSamplesCommand(final String name, + final String description, + final PushNotificationExperimentFactory experimentFactory) { + + super(new Application<>() { + @Override + public void run(final WhisperServerConfiguration configuration, final Environment environment) { + } + }, name, description); + + this.experimentFactory = experimentFactory; + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(DEFAULT_MAX_CONCURRENCY) + .help("Max concurrency for DynamoDB operations"); + } + + @Override + protected void run(final Environment environment, + final Namespace namespace, + final WhisperServerConfiguration configuration, + final CommandDependencies commandDependencies) throws Exception { + + final PushNotificationExperiment experiment = + experimentFactory.buildExperiment(commandDependencies, configuration); + + final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT); + + commandDependencies.pushNotificationExperimentSamples() + .discardSamples(experiment.getExperimentName(), maxConcurrency).join(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java new file mode 100644 index 000000000..38f7bd345 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommand.java @@ -0,0 +1,117 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.core.Application; +import io.dropwizard.core.setup.Environment; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import javax.annotation.Nullable; +import java.time.Duration; +import java.util.UUID; + +public class FinishPushNotificationExperimentCommand extends AbstractCommandWithDependencies { + + private final PushNotificationExperimentFactory experimentFactory; + + private static final int DEFAULT_MAX_CONCURRENCY = 16; + + @VisibleForTesting + static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + private static final Logger log = LoggerFactory.getLogger(FinishPushNotificationExperimentCommand.class); + + public FinishPushNotificationExperimentCommand(final String name, + final String description, + final PushNotificationExperimentFactory experimentFactory) { + + super(new Application<>() { + @Override + public void run(final WhisperServerConfiguration configuration, final Environment environment) { + } + }, name, description); + + this.experimentFactory = experimentFactory; + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(DEFAULT_MAX_CONCURRENCY) + .help("Max concurrency for DynamoDB operations"); + } + + @Override + protected void run(final Environment environment, + final Namespace namespace, + final WhisperServerConfiguration configuration, + final CommandDependencies commandDependencies) throws Exception { + + final PushNotificationExperiment experiment = + experimentFactory.buildExperiment(commandDependencies, configuration); + + final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT); + + final AccountsManager accountsManager = commandDependencies.accountsManager(); + final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples(); + + pushNotificationExperimentSamples.getDevicesPendingFinalState(experiment.getExperimentName()) + .flatMap(accountIdentifierAndDeviceId -> + Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(accountIdentifierAndDeviceId.getT1())) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) + .map(maybeAccount -> Tuples.of(accountIdentifierAndDeviceId.getT1(), accountIdentifierAndDeviceId.getT2(), maybeAccount)), maxConcurrency) + .map(accountIdentifierAndDeviceIdAndMaybeAccount -> { + final UUID accountIdentifier = accountIdentifierAndDeviceIdAndMaybeAccount.getT1(); + final byte deviceId = accountIdentifierAndDeviceIdAndMaybeAccount.getT2(); + + @Nullable final Account account = accountIdentifierAndDeviceIdAndMaybeAccount.getT3() + .orElse(null); + + @Nullable final Device device = accountIdentifierAndDeviceIdAndMaybeAccount.getT3() + .flatMap(a -> a.getDevice(deviceId)) + .orElse(null); + + return Tuples.of(accountIdentifier, deviceId, experiment.getState(account, device)); + }) + .flatMap(accountIdentifierAndDeviceIdAndFinalState -> { + final UUID accountIdentifier = accountIdentifierAndDeviceIdAndFinalState.getT1(); + final byte deviceId = accountIdentifierAndDeviceIdAndFinalState.getT2(); + final T finalState = accountIdentifierAndDeviceIdAndFinalState.getT3(); + + return Mono.fromFuture(() -> { + try { + return pushNotificationExperimentSamples.recordFinalState(accountIdentifier, deviceId, + experiment.getExperimentName(), finalState); + } catch (final JsonProcessingException e) { + throw new RuntimeException(e); + } + }) + .onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty()) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) + .onErrorResume(throwable -> { + log.warn("Failed to record final state for {}:{} in experiment {}", + accountIdentifier, deviceId, experiment.getExperimentName(), throwable); + + return Mono.empty(); + }); + }, maxConcurrency) + .then() + .block(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/PushNotificationExperimentFactory.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/PushNotificationExperimentFactory.java new file mode 100644 index 000000000..0aa2f6aa2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/PushNotificationExperimentFactory.java @@ -0,0 +1,10 @@ +package org.whispersystems.textsecuregcm.workers; + +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; + +public interface PushNotificationExperimentFactory { + + PushNotificationExperiment buildExperiment(CommandDependencies commandDependencies, + WhisperServerConfiguration configuration); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java index b3d09513f..c3187214f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java @@ -19,7 +19,6 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; public class ScheduledApnPushNotificationSenderServiceCommand extends ServerCommand { @@ -64,15 +63,12 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm }); } - final FaultTolerantRedisCluster pushSchedulerCluster = configuration.getPushSchedulerCluster() - .build("push_scheduler", deps.redisClusterClientResourcesBuilder()); - final ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")) .maxThreads(1).minThreads(1).build(); final APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); final ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler( - pushSchedulerCluster, apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT)); + deps.pushSchedulerCluster(), apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT)); environment.lifecycle().manage(apnSender); environment.lifecycle().manage(apnPushNotificationScheduler); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java new file mode 100644 index 000000000..2e68cd10f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommand.java @@ -0,0 +1,133 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.UUID; + +public class StartPushNotificationExperimentCommand extends AbstractSinglePassCrawlAccountsCommand { + + private final PushNotificationExperimentFactory experimentFactory; + + private static final int DEFAULT_MAX_CONCURRENCY = 16; + + @VisibleForTesting + static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + private static final Counter INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER = + Metrics.counter(MetricsUtil.name(StartPushNotificationExperimentCommand.class, "initialSampleAlreadyExists")); + + private static final Logger log = LoggerFactory.getLogger(StartPushNotificationExperimentCommand.class); + + public StartPushNotificationExperimentCommand(final String name, + final String description, + final PushNotificationExperimentFactory experimentFactory) { + + super(name, description); + this.experimentFactory = experimentFactory; + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(DEFAULT_MAX_CONCURRENCY) + .help("Max concurrency for DynamoDB operations"); + } + + @Override + protected void crawlAccounts(final Flux accounts) { + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + + final PushNotificationExperiment experiment = + experimentFactory.buildExperiment(getCommandDependencies(), getConfiguration()); + + final PushNotificationExperimentSamples pushNotificationExperimentSamples = + getCommandDependencies().pushNotificationExperimentSamples(); + + accounts + .flatMap(account -> Flux.fromIterable(account.getDevices()) + .map(device -> Tuples.of(account, device))) + .filterWhen(accountAndDevice -> Mono.fromFuture(() -> + experiment.isDeviceEligible(accountAndDevice.getT1(), accountAndDevice.getT2())), + maxConcurrency) + .flatMap(accountAndDevice -> { + final UUID accountIdentifier = accountAndDevice.getT1().getIdentifier(IdentityType.ACI); + final byte deviceId = accountAndDevice.getT2().getId(); + + return Mono.fromFuture(() -> { + try { + return pushNotificationExperimentSamples.recordInitialState( + accountIdentifier, + deviceId, + experiment.getExperimentName(), + isInExperimentGroup(accountIdentifier, deviceId, experiment.getExperimentName()), + experiment.getState(accountAndDevice.getT1(), accountAndDevice.getT2())); + } catch (final JsonProcessingException e) { + throw new UncheckedIOException(e); + } + }) + .mapNotNull(stateStored -> { + if (stateStored) { + return accountAndDevice; + } else { + INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER.increment(); + return null; + } + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) + .onRetryExhaustedThrow(((backoffSpec, retrySignal) -> retrySignal.failure()))) + .onErrorResume(throwable -> { + log.warn("Failed to record initial sample for {}:{} in experiment {}", + accountIdentifier, deviceId, experiment.getExperimentName(), throwable); + + return Mono.empty(); + }); + }, maxConcurrency) + .flatMap(accountAndDevice -> { + final Account account = accountAndDevice.getT1(); + final Device device = accountAndDevice.getT2(); + final boolean inExperimentGroup = + isInExperimentGroup(account.getIdentifier(IdentityType.ACI), device.getId(), experiment.getExperimentName()); + + return Mono.fromFuture(() -> inExperimentGroup + ? experiment.applyExperimentTreatment(account, device) + : experiment.applyControlTreatment(account, device)) + .onErrorResume(throwable -> { + log.warn("Failed to apply {} treatment for {}:{} in experiment {}", + inExperimentGroup ? "experimental" : " control", + account.getIdentifier(IdentityType.ACI), + device.getId(), + experiment.getExperimentName(), + throwable); + + return Mono.empty(); + }); + }, maxConcurrency) + .then() + .block(); + } + + private boolean isInExperimentGroup(final UUID accountIdentifier, final byte deviceId, final String experimentName) { + return ((accountIdentifier.hashCode() ^ Byte.hashCode(deviceId) ^ experimentName.hashCode()) & 0x01) != 0; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java new file mode 100644 index 000000000..004167f06 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/PushNotificationExperimentSamplesTest.java @@ -0,0 +1,331 @@ +package org.whispersystems.textsecuregcm.experiment; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.core.JsonProcessingException; +import java.time.Clock; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import reactor.util.function.Tuples; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; +import software.amazon.awssdk.services.dynamodb.model.Select; +import javax.annotation.Nullable; + +class PushNotificationExperimentSamplesTest { + + private PushNotificationExperimentSamples pushNotificationExperimentSamples; + + @RegisterExtension + public static final DynamoDbExtension DYNAMO_DB_EXTENSION = + new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES); + + private record TestDeviceState(int bounciness) { + } + + @BeforeEach + void setUp() { + pushNotificationExperimentSamples = + new PushNotificationExperimentSamples(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName(), + Clock.systemUTC()); + } + + @Test + void recordInitialState() throws JsonProcessingException { + final String experimentName = "test-experiment"; + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); + final int bounciness = ThreadLocalRandom.current().nextInt(); + + assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(bounciness)) + .join(), + "Attempt to record an initial state should succeed for entirely new records"); + + assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class)); + + assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(bounciness)) + .join(), + "Attempt to re-record an initial state should succeed for existing-but-unchanged records"); + + assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), + "Recorded initial state should be unchanged after repeated write"); + + assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(bounciness + 1)) + .join(), + "Attempt to record a conflicting initial state should fail"); + + assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), + "Recorded initial state should be unchanged after unsuccessful write"); + + assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + !inExperimentGroup, + new TestDeviceState(bounciness)) + .join(), + "Attempt to record a new group assignment should fail"); + + assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null), + getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), + "Recorded initial state should be unchanged after unsuccessful write"); + + final int finalBounciness = bounciness + 17; + + pushNotificationExperimentSamples.recordFinalState(accountIdentifier, + deviceId, + experimentName, + new TestDeviceState(finalBounciness)) + .join(); + + assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(bounciness)) + .join(), + "Attempt to record an initial state should fail for samples with final states"); + + assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), new TestDeviceState(finalBounciness)), + getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), + "Recorded initial state should be unchanged after unsuccessful write"); + } + + @Test + void recordFinalState() throws JsonProcessingException { + final String experimentName = "test-experiment"; + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); + final int initialBounciness = ThreadLocalRandom.current().nextInt(); + final int finalBounciness = initialBounciness + 17; + + { + pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(initialBounciness)) + .join(); + + final PushNotificationExperimentSample returnedSample = + pushNotificationExperimentSamples.recordFinalState(accountIdentifier, + deviceId, + experimentName, + new TestDeviceState(finalBounciness)) + .join(); + + final PushNotificationExperimentSample expectedSample = + new PushNotificationExperimentSample<>(inExperimentGroup, + new TestDeviceState(initialBounciness), + new TestDeviceState(finalBounciness)); + + assertEquals(expectedSample, returnedSample, + "Attempt to update existing sample without final state should succeed"); + + assertEquals(expectedSample, getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class), + "Attempt to update existing sample without final state should be persisted"); + } + + assertThrows(CompletionException.class, () -> pushNotificationExperimentSamples.recordFinalState(accountIdentifier, + (byte) (deviceId + 1), + experimentName, + new TestDeviceState(finalBounciness)) + .join(), + "Attempts to record a final state without an initial state should fail"); + } + + @SuppressWarnings("SameParameterValue") + private PushNotificationExperimentSample getSample(final UUID accountIdentifier, + final byte deviceId, + final String experimentName, + final Class stateClass) throws JsonProcessingException { + + final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName()) + .key(Map.of( + PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName), + PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID, PushNotificationExperimentSamples.buildSortKey(accountIdentifier, deviceId))) + .build()); + + final boolean inExperimentGroup = + response.item().get(PushNotificationExperimentSamples.ATTR_IN_EXPERIMENT_GROUP).bool(); + + final T initialState = + SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_INITIAL_STATE).s(), stateClass); + + @Nullable final T finalState = response.item().containsKey(PushNotificationExperimentSamples.ATTR_FINAL_STATE) + ? SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_FINAL_STATE).s(), stateClass) + : null; + + return new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState); + } + + @Test + void getDevicesPendingFinalState() throws JsonProcessingException { + final String experimentName = "test-experiment"; + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); + final int initialBounciness = ThreadLocalRandom.current().nextInt(); + + //noinspection DataFlowIssue + assertTrue(pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block().isEmpty()); + + pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(initialBounciness)) + .join(); + + pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + (byte) (deviceId + 1), + experimentName + "-different", + inExperimentGroup, + new TestDeviceState(initialBounciness)) + .join(); + + assertEquals(List.of(Tuples.of(accountIdentifier, deviceId)), + pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block()); + } + + @Test + void getFinishedSamples() throws JsonProcessingException { + final String experimentName = "test-experiment"; + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean(); + final int initialBounciness = ThreadLocalRandom.current().nextInt(); + final int finalBounciness = initialBounciness + 17; + + //noinspection DataFlowIssue + assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty()); + + pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName, + inExperimentGroup, + new TestDeviceState(initialBounciness)) + .join(); + + //noinspection DataFlowIssue + assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty(), + "Publisher should not return unfinished samples"); + + pushNotificationExperimentSamples.recordFinalState(accountIdentifier, + deviceId, + experimentName, + new TestDeviceState(finalBounciness)) + .join(); + + final List> expectedSamples = + List.of(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(initialBounciness), new TestDeviceState(finalBounciness))); + + assertEquals( + expectedSamples, + pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(), + "Publisher should return finished samples"); + + pushNotificationExperimentSamples.recordInitialState(accountIdentifier, + deviceId, + experimentName + "-different", + inExperimentGroup, + new TestDeviceState(initialBounciness)) + .join(); + + pushNotificationExperimentSamples.recordFinalState(accountIdentifier, + deviceId, + experimentName + "-different", + new TestDeviceState(finalBounciness)) + .join(); + + assertEquals( + expectedSamples, + pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(), + "Publisher should return finished samples only from named experiment"); + } + + @Test + void discardSamples() throws JsonProcessingException { + final String discardSamplesExperimentName = "discard-experiment"; + final String retainSamplesExperimentName = "retain-experiment"; + final int sampleCount = 16; + + for (int i = 0; i < sampleCount; i++) { + pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(), + Device.PRIMARY_ID, + discardSamplesExperimentName, + ThreadLocalRandom.current().nextBoolean(), + new TestDeviceState(ThreadLocalRandom.current().nextInt())) + .join(); + + pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(), + Device.PRIMARY_ID, + retainSamplesExperimentName, + ThreadLocalRandom.current().nextBoolean(), + new TestDeviceState(ThreadLocalRandom.current().nextInt())) + .join(); + } + + pushNotificationExperimentSamples.discardSamples(discardSamplesExperimentName, 1).join(); + + { + final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName()) + .keyConditionExpression("#experiment = :experiment") + .expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME)) + .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(discardSamplesExperimentName))) + .select(Select.COUNT) + .build()) + .join(); + + assertEquals(0, queryResponse.count()); + } + + { + final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder() + .tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName()) + .keyConditionExpression("#experiment = :experiment") + .expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME)) + .expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(retainSamplesExperimentName))) + .select(Select.COUNT) + .build()) + .join(); + + assertEquals(sampleCount, queryResponse.count()); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java index a19740a6e..9ce9100de 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java @@ -9,6 +9,7 @@ import java.util.Collections; import java.util.List; import org.whispersystems.textsecuregcm.backup.BackupsDb; import org.whispersystems.textsecuregcm.scheduler.JobScheduler; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; @@ -141,6 +142,20 @@ public final class DynamoDbExtensionSchema { .build()), List.of(), List.of()), + PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test", + PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME, + PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID, + List.of( + AttributeDefinition.builder() + .attributeName(PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME) + .attributeType(ScalarAttributeType.S) + .build(), + AttributeDefinition.builder() + .attributeName(PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID) + .attributeType(ScalarAttributeType.B) + .build()), + List.of(), List.of()), + REPEATED_USE_EC_SIGNED_PRE_KEYS("repeated_use_signed_ec_pre_keys_test", RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID, RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java new file mode 100644 index 000000000..87ad5ee2b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/FinishPushNotificationExperimentCommandTest.java @@ -0,0 +1,252 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.fasterxml.jackson.core.JsonProcessingException; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Flux; +import reactor.util.function.Tuples; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class FinishPushNotificationExperimentCommandTest { + + private CommandDependencies commandDependencies; + private PushNotificationExperiment experiment; + + private FinishPushNotificationExperimentCommand finishPushNotificationExperimentCommand; + + private static final String EXPERIMENT_NAME = "test"; + + private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand { + + public TestFinishPushNotificationExperimentCommand(final PushNotificationExperiment experiment) { + super("test-finish-push-notification-experiment", + "Test start push notification experiment command", + (ignoredDependencies, ignoredConfiguration) -> experiment); + } + } + + @BeforeEach + void setUp() throws JsonProcessingException { + final AccountsManager accountsManager = mock(AccountsManager.class); + + final PushNotificationExperimentSamples pushNotificationExperimentSamples = + mock(PushNotificationExperimentSamples.class); + + when(pushNotificationExperimentSamples.recordFinalState(any(), anyByte(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test"))); + + commandDependencies = new CommandDependencies(accountsManager, + null, + null, + null, + null, + null, + null, + null, + pushNotificationExperimentSamples, + null, + null, + null, + null, + null); + + //noinspection unchecked + experiment = mock(PushNotificationExperiment.class); + when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME); + when(experiment.getState(any(), any())).thenReturn("test"); + + finishPushNotificationExperimentCommand = new TestFinishPushNotificationExperimentCommand(experiment); + } + + @Test + void run() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + + final Account account = mock(Account.class); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(experiment).getState(account, device); + + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } + + @Test + void runMissingAccount() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(experiment).getState(null, null); + + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } + + @Test + void runMissingDevice() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Account account = mock(Account.class); + when(account.getDevice(anyByte())).thenReturn(Optional.empty()); + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(experiment).getState(account, null); + + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } + + @Test + void runAccountFetchRetry() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + + final Account account = mock(Account.class); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(commandDependencies.accountsManager(), times(3)).getByAccountIdentifierAsync(accountIdentifier); + + verify(experiment).getState(account, device); + + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } + + @Test + void runStoreSampleRetry() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + + final Account account = mock(Account.class); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test"))); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(experiment).getState(account, device); + + verify(commandDependencies.pushNotificationExperimentSamples(), times(3)) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } + + @Test + void runMissingInitialSample() throws Exception { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + + final Account account = mock(Account.class); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + + when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME)) + .thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId))); + + when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any())) + .thenReturn(CompletableFuture.failedFuture(ConditionalCheckFailedException.builder().build())); + + assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null, + new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)), + null, + commandDependencies)); + + verify(experiment).getState(account, device); + + verify(commandDependencies.pushNotificationExperimentSamples()) + .recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java new file mode 100644 index 000000000..308a817a7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/StartPushNotificationExperimentCommandTest.java @@ -0,0 +1,166 @@ +package org.whispersystems.textsecuregcm.workers; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.core.JsonProcessingException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment; +import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Flux; + +class StartPushNotificationExperimentCommandTest { + + private PushNotificationExperimentSamples pushNotificationExperimentSamples; + private PushNotificationExperiment experiment; + + private StartPushNotificationExperimentCommand startPushNotificationExperimentCommand; + + // Taken together, these parameters will produce a device that's enrolled in the experimental group (as opposed to the + // control group) for an experiment. + private static final UUID ACCOUNT_IDENTIFIER = UUID.fromString("341fb18f-9dee-4181-bc40-e485958341d3"); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + private static final String EXPERIMENT_NAME = "test"; + + private static class TestStartPushNotificationExperimentCommand extends StartPushNotificationExperimentCommand { + + private final CommandDependencies commandDependencies; + + public TestStartPushNotificationExperimentCommand( + final PushNotificationExperimentSamples pushNotificationExperimentSamples, + final PushNotificationExperiment experiment) { + + super("test-start-push-notification-experiment", + "Test start push notification experiment command", + (ignoredDependencies, ignoredConfiguration) -> experiment); + + this.commandDependencies = new CommandDependencies(null, + null, + null, + null, + null, + null, + null, + null, + pushNotificationExperimentSamples, + null, + null, + null, + null, + null); + } + + @Override + protected Namespace getNamespace() { + return new Namespace(Map.of(StartPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)); + } + + @Override + protected CommandDependencies getCommandDependencies() { + return commandDependencies; + } + } + + @BeforeEach + void setUp() { + //noinspection unchecked + experiment = mock(PushNotificationExperiment.class); + when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME); + when(experiment.isDeviceEligible(any(), any())).thenReturn(CompletableFuture.completedFuture(true)); + when(experiment.getState(any(), any())).thenReturn("test"); + when(experiment.applyExperimentTreatment(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + pushNotificationExperimentSamples = mock(PushNotificationExperimentSamples.class); + + try { + when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any())) + .thenReturn(CompletableFuture.completedFuture(true)); + } catch (final JsonProcessingException e) { + throw new AssertionError(e); + } + + startPushNotificationExperimentCommand = + new TestStartPushNotificationExperimentCommand(pushNotificationExperimentSamples, experiment); + } + + @Test + void crawlAccounts() { + final Device device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); + when(account.getDevices()).thenReturn(List.of(device)); + + assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); + verify(experiment).applyExperimentTreatment(account, device); + } + + @Test + void crawlAccountsExistingSample() throws JsonProcessingException { + final Device device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); + when(account.getDevices()).thenReturn(List.of(device)); + + when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any())) + .thenReturn(CompletableFuture.completedFuture(false)); + + assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); + verify(experiment, never()).applyExperimentTreatment(account, device); + } + + @Test + void crawlAccountsSampleRetry() throws JsonProcessingException { + final Device device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); + when(account.getDevices()).thenReturn(List.of(device)); + + when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())) + .thenReturn(CompletableFuture.completedFuture(true)); + + assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); + verify(experiment).applyExperimentTreatment(account, device); + verify(pushNotificationExperimentSamples, times(3)) + .recordInitialState(ACCOUNT_IDENTIFIER, DEVICE_ID, EXPERIMENT_NAME, true, "test"); + } + + @Test + void crawlAccountsExperimentException() { + final Device device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); + when(account.getDevices()).thenReturn(List.of(device)); + + when(experiment.applyExperimentTreatment(account, device)) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())); + + assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account))); + verify(experiment).applyExperimentTreatment(account, device); + } +} diff --git a/service/src/test/resources/config/test.yml b/service/src/test/resources/config/test.yml index 4487311a0..6ea660997 100644 --- a/service/src/test/resources/config/test.yml +++ b/service/src/test/resources/config/test.yml @@ -110,6 +110,8 @@ dynamoDbTables: tableName: profiles_test pushChallenge: tableName: push_challenge_test + pushNotificationExperimentSamples: + tableName: Example_PushNotificationExperimentSamples redeemedReceipts: tableName: redeemed_receipts_test expiration: P30D # Duration of time until rows expire