Add a framework for running experiments to improve push notification reliability

This commit is contained in:
Jon Chambers 2024-07-25 11:36:05 -04:00 committed by GitHub
parent 1fe6dac760
commit 4ebad2c473
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 1489 additions and 8 deletions

View File

@ -116,6 +116,8 @@ dynamoDbTables:
tableName: Example_Profiles
pushChallenge:
tableName: Example_PushChallenge
pushNotificationExperimentSamples:
tableName: Example_PushNotificationExperimentSamples
redeemedReceipts:
tableName: Example_RedeemedReceipts
expiration: P30D # Duration of time until rows expire

View File

@ -63,6 +63,7 @@ public class DynamoDbTables {
private final Table phoneNumberIdentifiers;
private final Table profiles;
private final Table pushChallenge;
private final Table pushNotificationExperimentSamples;
private final TableWithExpiration redeemedReceipts;
private final TableWithExpiration registrationRecovery;
private final Table remoteConfig;
@ -88,6 +89,7 @@ public class DynamoDbTables {
@JsonProperty("phoneNumberIdentifiers") final Table phoneNumberIdentifiers,
@JsonProperty("profiles") final Table profiles,
@JsonProperty("pushChallenge") final Table pushChallenge,
@JsonProperty("pushNotificationExperimentSamples") final Table pushNotificationExperimentSamples,
@JsonProperty("redeemedReceipts") final TableWithExpiration redeemedReceipts,
@JsonProperty("registrationRecovery") final TableWithExpiration registrationRecovery,
@JsonProperty("remoteConfig") final Table remoteConfig,
@ -112,6 +114,7 @@ public class DynamoDbTables {
this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.profiles = profiles;
this.pushChallenge = pushChallenge;
this.pushNotificationExperimentSamples = pushNotificationExperimentSamples;
this.redeemedReceipts = redeemedReceipts;
this.registrationRecovery = registrationRecovery;
this.remoteConfig = remoteConfig;
@ -217,6 +220,12 @@ public class DynamoDbTables {
return pushChallenge;
}
@NotNull
@Valid
public Table getPushNotificationExperimentSamples() {
return pushNotificationExperimentSamples;
}
@NotNull
@Valid
public TableWithExpiration getRedeemedReceipts() {

View File

@ -0,0 +1,68 @@
package org.whispersystems.textsecuregcm.experiment;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.annotation.Nullable;
import java.util.concurrent.CompletableFuture;
/**
* A push notification selects for eligible devices, applies a control or experimental treatment, and provides a
* mechanism for comparing device states before and after receiving the treatment.
*
* @param <T> the type of state object stored for this experiment
*/
public interface PushNotificationExperiment<T> {
/**
* Returns the unique name of this experiment.
*
* @return the unique name of this experiment
*/
String getExperimentName();
/**
* Tests whether a device is eligible for this experiment. An eligible device may be assigned to either the control
* or experiment group within an experiment. Ineligible devices will not participate in the experiment in any way.
*
* @param account the account to which the device belongs
* @param device the device to test for eligibility in this experiment
*
* @return a future that yields a boolean value indicating whether the target device is eligible for this experiment
*/
CompletableFuture<Boolean> isDeviceEligible(Account account, Device device);
/**
* Generates an experiment specific state "snapshot" of the given device. Experiment results are generally evaluated
* by comparing a device's state before a treatment is applied and its state after the treatment is applied.
*
* @param account the account to which the device belongs
* @param device the device for which to generate a state "snapshot"
*
* @return an experiment-specific state "snapshot" of the given device
*/
T getState(@Nullable Account account, @Nullable Device device);
/**
* Applies a control treatment to the given device. In many cases (and by default) no action is taken for devices in
* the control group.
*
* @param account the account to which the device belongs
* @param device the device to which to apply the control treatment for this experiment
*
* @return a future that completes when the control treatment has been applied for the given device
*/
default CompletableFuture<Void> applyControlTreatment(Account account, Device device) {
return CompletableFuture.completedFuture(null);
};
/**
* Applies an experimental treatment to the given device. This generally involves sending or scheduling a specific
* type of push notification for the given device.
*
* @param account the account to which the device belongs
* @param device the device to which to apply the experimental treatment for this experiment
*
* @return a future that completes when the experimental treatment has been applied for the given device
*/
CompletableFuture<Void> applyExperimentTreatment(Account account, Device device);
}

View File

@ -0,0 +1,4 @@
package org.whispersystems.textsecuregcm.experiment;
public record PushNotificationExperimentSample<T>(boolean inExperimentGroup, T initialState, T finalState) {
}

View File

@ -0,0 +1,284 @@
package org.whispersystems.textsecuregcm.experiment;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.nio.ByteBuffer;
import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
public class PushNotificationExperimentSamples {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
private final Clock clock;
// Experiment name; DynamoDB string; partition key
public static final String KEY_EXPERIMENT_NAME = "N";
// Combined ACI and device ID; DynamoDB byte array; sort key
public static final String ATTR_ACI_AND_DEVICE_ID = "AD";
// Whether the device is enrolled in the experiment group (as opposed to control group); DynamoDB boolean
static final String ATTR_IN_EXPERIMENT_GROUP = "X";
// The experiment-specific state of the device at the start of the experiment, represented as a JSON blob; DynamoDB
// string
static final String ATTR_INITIAL_STATE = "I";
// The experiment-specific state of the device at the end of the experiment, represented as a JSON blob; DynamoDB
// string
static final String ATTR_FINAL_STATE = "F";
// The time, in seconds since the epoch, at which this sample should be deleted automatically
static final String ATTR_TTL = "E";
private static final Duration FINAL_SAMPLE_TTL = Duration.ofDays(7);
private static final Logger log = LoggerFactory.getLogger(PushNotificationExperimentSamples.class);
public PushNotificationExperimentSamples(final DynamoDbAsyncClient dynamoDbAsyncClient,
final String tableName,
final Clock clock) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
this.clock = clock;
}
/**
* Writes the initial state of a device participating in a push notification experiment.
*
* @param accountIdentifier the account identifier for the account to which the target device is linked
* @param deviceId the identifier for the device within the given account
* @param experimentName the name of the experiment
* @param inExperimentGroup whether the given device is in the experiment group (as opposed to control group)
* @param initialState the initial state of the object; must be serializable as a JSON text
*
* @return a future that completes when the record has been stored; the future yields {@code true} if a new record
* was stored or {@code false} if a conflicting record already exists
*
* @param <T> the type of state object for this sample
*
* @throws JsonProcessingException if the given {@code initialState} could not be serialized as a JSON text
*/
public <T> CompletableFuture<Boolean> recordInitialState(final UUID accountIdentifier,
final byte deviceId,
final String experimentName,
final boolean inExperimentGroup,
final T initialState) throws JsonProcessingException {
final AttributeValue initialStateAttributeValue =
AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(initialState));
final AttributeValue inExperimentGroupAttributeValue = AttributeValue.fromBool(inExperimentGroup);
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
ATTR_ACI_AND_DEVICE_ID, buildSortKey(accountIdentifier, deviceId),
ATTR_IN_EXPERIMENT_GROUP, inExperimentGroupAttributeValue,
ATTR_INITIAL_STATE, initialStateAttributeValue,
ATTR_TTL, AttributeValue.fromN(String.valueOf(clock.instant().plus(FINAL_SAMPLE_TTL).getEpochSecond()))))
.conditionExpression("(attribute_not_exists(#inExperimentGroup) OR #inExperimentGroup = :inExperimentGroup) AND (attribute_not_exists(#initialState) OR #initialState = :initialState) AND attribute_not_exists(#finalState)")
.expressionAttributeNames(Map.of(
"#inExperimentGroup", ATTR_IN_EXPERIMENT_GROUP,
"#initialState", ATTR_INITIAL_STATE,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(
":inExperimentGroup", inExperimentGroupAttributeValue,
":initialState", initialStateAttributeValue))
.build())
.thenApply(ignored -> true)
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) {
return false;
}
throw ExceptionUtils.wrap(throwable);
});
}
/**
* Writes the final state of a device participating in a push notification experiment.
*
* @param accountIdentifier the account identifier for the account to which the target device is linked
* @param deviceId the identifier for the device within the given account
* @param experimentName the name of the experiment
* @param finalState the final state of the object; must be serializable as a JSON text and of the same type as the
* previously-stored initial state
* @return a future that completes when the final state has been stored; yields a finished sample if an initial sample
* was found or empty if no initial sample was found for the given account, device, and experiment
*
* @param <T> the type of state object for this sample
*
* @throws JsonProcessingException if the given {@code finalState} could not be serialized as a JSON text
*/
public <T> CompletableFuture<PushNotificationExperimentSample<T>> recordFinalState(final UUID accountIdentifier,
final byte deviceId,
final String experimentName,
final T finalState) throws JsonProcessingException {
final AttributeValue aciAndDeviceIdAttributeValue = buildSortKey(accountIdentifier, deviceId);
return dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
ATTR_ACI_AND_DEVICE_ID, aciAndDeviceIdAttributeValue))
// `UpdateItem` will, by default, create a new item if one does not already exist for the given primary key. We
// want update-only-if-exists behavior, though, and so check that there's already an existing item for this ACI
// and device ID.
.conditionExpression("#aciAndDeviceId = :aciAndDeviceId")
.updateExpression("SET #finalState = if_not_exists(#finalState, :finalState)")
.expressionAttributeNames(Map.of(
"#aciAndDeviceId", ATTR_ACI_AND_DEVICE_ID,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(
":aciAndDeviceId", aciAndDeviceIdAttributeValue,
":finalState", AttributeValue.fromS(SystemMapper.jsonMapper().writeValueAsString(finalState))))
.returnValues(ReturnValue.ALL_NEW)
.build())
.thenApply(updateItemResponse -> {
try {
final boolean inExperimentGroup = updateItemResponse.attributes().get(ATTR_IN_EXPERIMENT_GROUP).bool();
@SuppressWarnings("unchecked") final T parsedInitialState =
(T) parseState(updateItemResponse.attributes().get(ATTR_INITIAL_STATE).s(), finalState.getClass());
@SuppressWarnings("unchecked") final T parsedFinalState =
(T) parseState(updateItemResponse.attributes().get(ATTR_FINAL_STATE).s(), finalState.getClass());
return new PushNotificationExperimentSample<>(inExperimentGroup, parsedInitialState, parsedFinalState);
} catch (final JsonProcessingException e) {
throw ExceptionUtils.wrap(e);
}
});
}
/**
* Returns a publisher across all samples pending a final state for a given experiment.
*
* @param experimentName the name of the experiment for which to retrieve samples pending a final state
*
* @return a publisher across all samples pending a final state for a given experiment
*/
public Flux<Tuple2<UUID, Byte>> getDevicesPendingFinalState(final String experimentName) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#experiment = :experiment")
.filterExpression("attribute_not_exists(#finalState)")
.expressionAttributeNames(Map.of(
"#experiment", KEY_EXPERIMENT_NAME,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName)))
.projectionExpression(ATTR_ACI_AND_DEVICE_ID)
.build())
.items())
.map(item -> parseSortKey(item.get(ATTR_ACI_AND_DEVICE_ID)));
}
/**
* Returns a publisher across all finished samples (i.e. samples with a recorded final state) for a given experiment.
*
* @param experimentName the name of the experiment for which to retrieve finished samples
* @param stateClass the type of state object for sample in the given experiment
*
* @return a publisher across all finished samples for the given experiment
*
* @param <T> the type of the sample's state objects
*/
public <T> Flux<PushNotificationExperimentSample<T>> getFinishedSamples(final String experimentName,
final Class<T> stateClass) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#experiment = :experiment")
.filterExpression("attribute_exists(#finalState)")
.expressionAttributeNames(Map.of(
"#experiment", KEY_EXPERIMENT_NAME,
"#finalState", ATTR_FINAL_STATE))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(experimentName)))
.build())
.items())
.handle((item, sink) -> {
try {
final boolean inExperimentGroup = item.get(ATTR_IN_EXPERIMENT_GROUP).bool();
final T initialState = parseState(item.get(ATTR_INITIAL_STATE).s(), stateClass);
final T finalState = parseState(item.get(ATTR_FINAL_STATE).s(), stateClass);
sink.next(new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState));
} catch (final JsonProcessingException e) {
sink.error(e);
}
});
}
public CompletableFuture<Void> discardSamples(final String experimentName, final int maxConcurrency) {
final AttributeValue experimentNameAttributeValue = AttributeValue.fromS(experimentName);
return Flux.from(dynamoDbAsyncClient.scanPaginator(ScanRequest.builder()
.tableName(tableName)
.filterExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", experimentNameAttributeValue))
.projectionExpression(ATTR_ACI_AND_DEVICE_ID)
.build())
.items())
.map(item -> item.get(ATTR_ACI_AND_DEVICE_ID))
.flatMap(aciAndDeviceId -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_EXPERIMENT_NAME, experimentNameAttributeValue,
ATTR_ACI_AND_DEVICE_ID, aciAndDeviceId))
.build()))
.retryWhen(Retry.backoff(5, Duration.ofSeconds(5)))
.onErrorResume(throwable -> {
log.warn("Failed to delete sample for experiment {}", experimentName, throwable);
return Mono.empty();
}), maxConcurrency)
.then()
.toFuture();
}
@VisibleForTesting
static AttributeValue buildSortKey(final UUID accountIdentifier, final byte deviceId) {
return AttributeValue.fromB(SdkBytes.fromByteBuffer(ByteBuffer.allocate(17)
.putLong(accountIdentifier.getMostSignificantBits())
.putLong(accountIdentifier.getLeastSignificantBits())
.put(deviceId)
.flip()));
}
private static Tuple2<UUID, Byte> parseSortKey(final AttributeValue sortKey) {
final ByteBuffer byteBuffer = sortKey.b().asByteBuffer();
return Tuples.of(new UUID(byteBuffer.getLong(), byteBuffer.getLong()), byteBuffer.get());
}
private static <T> T parseState(final String state, final Class<T> clazz) throws JsonProcessingException {
return SystemMapper.jsonMapper().readValue(state, clazz);
}
}

View File

@ -11,6 +11,8 @@ import com.fasterxml.jackson.databind.DeserializationFeature;
import io.dropwizard.core.setup.Environment;
import io.lettuce.core.resource.ClientResources;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.time.Clock;
import java.util.concurrent.ExecutorService;
@ -28,8 +30,13 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -68,7 +75,10 @@ record CommandDependencies(
MessagesManager messagesManager,
ClientPresenceManager clientPresenceManager,
KeysManager keysManager,
PushNotificationManager pushNotificationManager,
PushNotificationExperimentSamples pushNotificationExperimentSamples,
FaultTolerantRedisCluster cacheCluster,
FaultTolerantRedisCluster pushSchedulerCluster,
ClientResources.Builder redisClusterClientResourcesBuilder,
BackupManager backupManager,
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
@ -76,7 +86,8 @@ record CommandDependencies(
static CommandDependencies build(
final String name,
final Environment environment,
final WhisperServerConfiguration configuration) throws IOException, CertificateException {
final WhisperServerConfiguration configuration)
throws IOException, CertificateException, NoSuchAlgorithmException, InvalidKeyException {
Clock clock = Clock.systemUTC();
environment.getObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
@ -92,8 +103,10 @@ record CommandDependencies(
final ClientResources.Builder redisClientResourcesBuilder = ClientResources.builder();
FaultTolerantRedisCluster cacheCluster = configuration.getCacheClusterConfiguration().build("main_cache",
redisClientResourcesBuilder);
FaultTolerantRedisCluster cacheCluster = configuration.getCacheClusterConfiguration()
.build("main_cache", redisClientResourcesBuilder);
FaultTolerantRedisCluster pushSchedulerCluster = configuration.getPushSchedulerCluster()
.build("push_scheduler", redisClientResourcesBuilder);
ScheduledExecutorService recurringJobExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build();
@ -115,6 +128,10 @@ record CommandDependencies(
.executorService(name(name, "remoteStorage-%d"))
.minThreads(0).maxThreads(Integer.MAX_VALUE).workQueue(new SynchronousQueue<>())
.keepAliveTime(io.dropwizard.util.Duration.seconds(60L)).build();
ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(name, "apnSender-%d"))
.maxThreads(1).minThreads(1).build();
ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d"))
.maxThreads(16).minThreads(16).build();
ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build();
@ -225,6 +242,16 @@ record CommandDependencies(
remoteStorageRetryExecutor,
configuration.getCdn3StorageManagerConfiguration()),
clock);
APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster,
apnSender, accountsManager, 0);
PushNotificationManager pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler);
PushNotificationExperimentSamples pushNotificationExperimentSamples =
new PushNotificationExperimentSamples(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(),
Clock.systemUTC());
environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager);
@ -238,7 +265,10 @@ record CommandDependencies(
messagesManager,
clientPresenceManager,
keys,
pushNotificationManager,
pushNotificationExperimentSamples,
cacheCluster,
pushSchedulerCluster,
redisClientResourcesBuilder,
backupManager,
dynamicConfigurationManager

View File

@ -0,0 +1,62 @@
package org.whispersystems.textsecuregcm.workers;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
public class DiscardPushNotificationExperimentSamplesCommand extends AbstractCommandWithDependencies {
private final PushNotificationExperimentFactory<?> experimentFactory;
private static final int DEFAULT_MAX_CONCURRENCY = 16;
@VisibleForTesting
static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final Logger log = LoggerFactory.getLogger(DiscardPushNotificationExperimentSamplesCommand.class);
public DiscardPushNotificationExperimentSamplesCommand(final String name,
final String description,
final PushNotificationExperimentFactory<?> experimentFactory) {
super(new Application<>() {
@Override
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
}
}, name, description);
this.experimentFactory = experimentFactory;
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
}
@Override
protected void run(final Environment environment,
final Namespace namespace,
final WhisperServerConfiguration configuration,
final CommandDependencies commandDependencies) throws Exception {
final PushNotificationExperiment<?> experiment =
experimentFactory.buildExperiment(commandDependencies, configuration);
final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT);
commandDependencies.pushNotificationExperimentSamples()
.discardSamples(experiment.getExperimentName(), maxConcurrency).join();
}
}

View File

@ -0,0 +1,117 @@
package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import javax.annotation.Nullable;
import java.time.Duration;
import java.util.UUID;
public class FinishPushNotificationExperimentCommand<T> extends AbstractCommandWithDependencies {
private final PushNotificationExperimentFactory<T> experimentFactory;
private static final int DEFAULT_MAX_CONCURRENCY = 16;
@VisibleForTesting
static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final Logger log = LoggerFactory.getLogger(FinishPushNotificationExperimentCommand.class);
public FinishPushNotificationExperimentCommand(final String name,
final String description,
final PushNotificationExperimentFactory<T> experimentFactory) {
super(new Application<>() {
@Override
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
}
}, name, description);
this.experimentFactory = experimentFactory;
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
}
@Override
protected void run(final Environment environment,
final Namespace namespace,
final WhisperServerConfiguration configuration,
final CommandDependencies commandDependencies) throws Exception {
final PushNotificationExperiment<T> experiment =
experimentFactory.buildExperiment(commandDependencies, configuration);
final int maxConcurrency = namespace.getInt(MAX_CONCURRENCY_ARGUMENT);
final AccountsManager accountsManager = commandDependencies.accountsManager();
final PushNotificationExperimentSamples pushNotificationExperimentSamples = commandDependencies.pushNotificationExperimentSamples();
pushNotificationExperimentSamples.getDevicesPendingFinalState(experiment.getExperimentName())
.flatMap(accountIdentifierAndDeviceId ->
Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(accountIdentifierAndDeviceId.getT1()))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.map(maybeAccount -> Tuples.of(accountIdentifierAndDeviceId.getT1(), accountIdentifierAndDeviceId.getT2(), maybeAccount)), maxConcurrency)
.map(accountIdentifierAndDeviceIdAndMaybeAccount -> {
final UUID accountIdentifier = accountIdentifierAndDeviceIdAndMaybeAccount.getT1();
final byte deviceId = accountIdentifierAndDeviceIdAndMaybeAccount.getT2();
@Nullable final Account account = accountIdentifierAndDeviceIdAndMaybeAccount.getT3()
.orElse(null);
@Nullable final Device device = accountIdentifierAndDeviceIdAndMaybeAccount.getT3()
.flatMap(a -> a.getDevice(deviceId))
.orElse(null);
return Tuples.of(accountIdentifier, deviceId, experiment.getState(account, device));
})
.flatMap(accountIdentifierAndDeviceIdAndFinalState -> {
final UUID accountIdentifier = accountIdentifierAndDeviceIdAndFinalState.getT1();
final byte deviceId = accountIdentifierAndDeviceIdAndFinalState.getT2();
final T finalState = accountIdentifierAndDeviceIdAndFinalState.getT3();
return Mono.fromFuture(() -> {
try {
return pushNotificationExperimentSamples.recordFinalState(accountIdentifier, deviceId,
experiment.getExperimentName(), finalState);
} catch (final JsonProcessingException e) {
throw new RuntimeException(e);
}
})
.onErrorResume(ConditionalCheckFailedException.class, throwable -> Mono.empty())
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.onErrorResume(throwable -> {
log.warn("Failed to record final state for {}:{} in experiment {}",
accountIdentifier, deviceId, experiment.getExperimentName(), throwable);
return Mono.empty();
});
}, maxConcurrency)
.then()
.block();
}
}

View File

@ -0,0 +1,10 @@
package org.whispersystems.textsecuregcm.workers;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
public interface PushNotificationExperimentFactory<T> {
PushNotificationExperiment<T> buildExperiment(CommandDependencies commandDependencies,
WhisperServerConfiguration configuration);
}

View File

@ -19,7 +19,6 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
public class ScheduledApnPushNotificationSenderServiceCommand extends ServerCommand<WhisperServerConfiguration> {
@ -64,15 +63,12 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm
});
}
final FaultTolerantRedisCluster pushSchedulerCluster = configuration.getPushSchedulerCluster()
.build("push_scheduler", deps.redisClusterClientResourcesBuilder());
final ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d"))
.maxThreads(1).minThreads(1).build();
final APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration());
final ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(
pushSchedulerCluster, apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT));
deps.pushSchedulerCluster(), apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT));
environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(apnPushNotificationScheduler);

View File

@ -0,0 +1,133 @@
package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuples;
import reactor.util.retry.Retry;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.UUID;
public class StartPushNotificationExperimentCommand<T> extends AbstractSinglePassCrawlAccountsCommand {
private final PushNotificationExperimentFactory<T> experimentFactory;
private static final int DEFAULT_MAX_CONCURRENCY = 16;
@VisibleForTesting
static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final Counter INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER =
Metrics.counter(MetricsUtil.name(StartPushNotificationExperimentCommand.class, "initialSampleAlreadyExists"));
private static final Logger log = LoggerFactory.getLogger(StartPushNotificationExperimentCommand.class);
public StartPushNotificationExperimentCommand(final String name,
final String description,
final PushNotificationExperimentFactory<T> experimentFactory) {
super(name, description);
this.experimentFactory = experimentFactory;
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
}
@Override
protected void crawlAccounts(final Flux<Account> accounts) {
final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT);
final PushNotificationExperiment<T> experiment =
experimentFactory.buildExperiment(getCommandDependencies(), getConfiguration());
final PushNotificationExperimentSamples pushNotificationExperimentSamples =
getCommandDependencies().pushNotificationExperimentSamples();
accounts
.flatMap(account -> Flux.fromIterable(account.getDevices())
.map(device -> Tuples.of(account, device)))
.filterWhen(accountAndDevice -> Mono.fromFuture(() ->
experiment.isDeviceEligible(accountAndDevice.getT1(), accountAndDevice.getT2())),
maxConcurrency)
.flatMap(accountAndDevice -> {
final UUID accountIdentifier = accountAndDevice.getT1().getIdentifier(IdentityType.ACI);
final byte deviceId = accountAndDevice.getT2().getId();
return Mono.fromFuture(() -> {
try {
return pushNotificationExperimentSamples.recordInitialState(
accountIdentifier,
deviceId,
experiment.getExperimentName(),
isInExperimentGroup(accountIdentifier, deviceId, experiment.getExperimentName()),
experiment.getState(accountAndDevice.getT1(), accountAndDevice.getT2()));
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
})
.mapNotNull(stateStored -> {
if (stateStored) {
return accountAndDevice;
} else {
INITIAL_SAMPLE_ALREADY_EXISTS_COUNTER.increment();
return null;
}
})
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
.onRetryExhaustedThrow(((backoffSpec, retrySignal) -> retrySignal.failure())))
.onErrorResume(throwable -> {
log.warn("Failed to record initial sample for {}:{} in experiment {}",
accountIdentifier, deviceId, experiment.getExperimentName(), throwable);
return Mono.empty();
});
}, maxConcurrency)
.flatMap(accountAndDevice -> {
final Account account = accountAndDevice.getT1();
final Device device = accountAndDevice.getT2();
final boolean inExperimentGroup =
isInExperimentGroup(account.getIdentifier(IdentityType.ACI), device.getId(), experiment.getExperimentName());
return Mono.fromFuture(() -> inExperimentGroup
? experiment.applyExperimentTreatment(account, device)
: experiment.applyControlTreatment(account, device))
.onErrorResume(throwable -> {
log.warn("Failed to apply {} treatment for {}:{} in experiment {}",
inExperimentGroup ? "experimental" : " control",
account.getIdentifier(IdentityType.ACI),
device.getId(),
experiment.getExperimentName(),
throwable);
return Mono.empty();
});
}, maxConcurrency)
.then()
.block();
}
private boolean isInExperimentGroup(final UUID accountIdentifier, final byte deviceId, final String experimentName) {
return ((accountIdentifier.hashCode() ^ Byte.hashCode(deviceId) ^ experimentName.hashCode()) & 0x01) != 0;
}
}

View File

@ -0,0 +1,331 @@
package org.whispersystems.textsecuregcm.experiment;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
import javax.annotation.Nullable;
class PushNotificationExperimentSamplesTest {
private PushNotificationExperimentSamples pushNotificationExperimentSamples;
@RegisterExtension
public static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES);
private record TestDeviceState(int bounciness) {
}
@BeforeEach
void setUp() {
pushNotificationExperimentSamples =
new PushNotificationExperimentSamples(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName(),
Clock.systemUTC());
}
@Test
void recordInitialState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int bounciness = ThreadLocalRandom.current().nextInt();
assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record an initial state should succeed for entirely new records");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class));
assertTrue(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to re-record an initial state should succeed for existing-but-unchanged records");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after repeated write");
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness + 1))
.join(),
"Attempt to record a conflicting initial state should fail");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
!inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record a new group assignment should fail");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), null),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
final int finalBounciness = bounciness + 17;
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
assertFalse(pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(bounciness))
.join(),
"Attempt to record an initial state should fail for samples with final states");
assertEquals(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(bounciness), new TestDeviceState(finalBounciness)),
getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Recorded initial state should be unchanged after unsuccessful write");
}
@Test
void recordFinalState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
final int finalBounciness = initialBounciness + 17;
{
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
final PushNotificationExperimentSample<TestDeviceState> returnedSample =
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
final PushNotificationExperimentSample<TestDeviceState> expectedSample =
new PushNotificationExperimentSample<>(inExperimentGroup,
new TestDeviceState(initialBounciness),
new TestDeviceState(finalBounciness));
assertEquals(expectedSample, returnedSample,
"Attempt to update existing sample without final state should succeed");
assertEquals(expectedSample, getSample(accountIdentifier, deviceId, experimentName, TestDeviceState.class),
"Attempt to update existing sample without final state should be persisted");
}
assertThrows(CompletionException.class, () -> pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
(byte) (deviceId + 1),
experimentName,
new TestDeviceState(finalBounciness))
.join(),
"Attempts to record a final state without an initial state should fail");
}
@SuppressWarnings("SameParameterValue")
private <T> PushNotificationExperimentSample<T> getSample(final UUID accountIdentifier,
final byte deviceId,
final String experimentName,
final Class<T> stateClass) throws JsonProcessingException {
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.key(Map.of(
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME, AttributeValue.fromS(experimentName),
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID, PushNotificationExperimentSamples.buildSortKey(accountIdentifier, deviceId)))
.build());
final boolean inExperimentGroup =
response.item().get(PushNotificationExperimentSamples.ATTR_IN_EXPERIMENT_GROUP).bool();
final T initialState =
SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_INITIAL_STATE).s(), stateClass);
@Nullable final T finalState = response.item().containsKey(PushNotificationExperimentSamples.ATTR_FINAL_STATE)
? SystemMapper.jsonMapper().readValue(response.item().get(PushNotificationExperimentSamples.ATTR_FINAL_STATE).s(), stateClass)
: null;
return new PushNotificationExperimentSample<>(inExperimentGroup, initialState, finalState);
}
@Test
void getDevicesPendingFinalState() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block().isEmpty());
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
(byte) (deviceId + 1),
experimentName + "-different",
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
assertEquals(List.of(Tuples.of(accountIdentifier, deviceId)),
pushNotificationExperimentSamples.getDevicesPendingFinalState(experimentName).collectList().block());
}
@Test
void getFinishedSamples() throws JsonProcessingException {
final String experimentName = "test-experiment";
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final boolean inExperimentGroup = ThreadLocalRandom.current().nextBoolean();
final int initialBounciness = ThreadLocalRandom.current().nextInt();
final int finalBounciness = initialBounciness + 17;
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty());
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName,
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
//noinspection DataFlowIssue
assertTrue(pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block().isEmpty(),
"Publisher should not return unfinished samples");
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName,
new TestDeviceState(finalBounciness))
.join();
final List<PushNotificationExperimentSample<TestDeviceState>> expectedSamples =
List.of(new PushNotificationExperimentSample<>(inExperimentGroup, new TestDeviceState(initialBounciness), new TestDeviceState(finalBounciness)));
assertEquals(
expectedSamples,
pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(),
"Publisher should return finished samples");
pushNotificationExperimentSamples.recordInitialState(accountIdentifier,
deviceId,
experimentName + "-different",
inExperimentGroup,
new TestDeviceState(initialBounciness))
.join();
pushNotificationExperimentSamples.recordFinalState(accountIdentifier,
deviceId,
experimentName + "-different",
new TestDeviceState(finalBounciness))
.join();
assertEquals(
expectedSamples,
pushNotificationExperimentSamples.getFinishedSamples(experimentName, TestDeviceState.class).collectList().block(),
"Publisher should return finished samples only from named experiment");
}
@Test
void discardSamples() throws JsonProcessingException {
final String discardSamplesExperimentName = "discard-experiment";
final String retainSamplesExperimentName = "retain-experiment";
final int sampleCount = 16;
for (int i = 0; i < sampleCount; i++) {
pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(),
Device.PRIMARY_ID,
discardSamplesExperimentName,
ThreadLocalRandom.current().nextBoolean(),
new TestDeviceState(ThreadLocalRandom.current().nextInt()))
.join();
pushNotificationExperimentSamples.recordInitialState(UUID.randomUUID(),
Device.PRIMARY_ID,
retainSamplesExperimentName,
ThreadLocalRandom.current().nextBoolean(),
new TestDeviceState(ThreadLocalRandom.current().nextInt()))
.join();
}
pushNotificationExperimentSamples.discardSamples(discardSamplesExperimentName, 1).join();
{
final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.keyConditionExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(discardSamplesExperimentName)))
.select(Select.COUNT)
.build())
.join();
assertEquals(0, queryResponse.count());
}
{
final QueryResponse queryResponse = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().query(QueryRequest.builder()
.tableName(DynamoDbExtensionSchema.Tables.PUSH_NOTIFICATION_EXPERIMENT_SAMPLES.tableName())
.keyConditionExpression("#experiment = :experiment")
.expressionAttributeNames(Map.of("#experiment", PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME))
.expressionAttributeValues(Map.of(":experiment", AttributeValue.fromS(retainSamplesExperimentName)))
.select(Select.COUNT)
.build())
.join();
assertEquals(sampleCount, queryResponse.count());
}
}
}

View File

@ -9,6 +9,7 @@ import java.util.Collections;
import java.util.List;
import org.whispersystems.textsecuregcm.backup.BackupsDb;
import org.whispersystems.textsecuregcm.scheduler.JobScheduler;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
@ -141,6 +142,20 @@ public final class DynamoDbExtensionSchema {
.build()),
List.of(), List.of()),
PUSH_NOTIFICATION_EXPERIMENT_SAMPLES("push_notification_experiment_samples_test",
PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME,
PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(PushNotificationExperimentSamples.KEY_EXPERIMENT_NAME)
.attributeType(ScalarAttributeType.S)
.build(),
AttributeDefinition.builder()
.attributeName(PushNotificationExperimentSamples.ATTR_ACI_AND_DEVICE_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
REPEATED_USE_EC_SIGNED_PRE_KEYS("repeated_use_signed_ec_pre_keys_test",
RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID,
RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID,

View File

@ -0,0 +1,252 @@
package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.core.JsonProcessingException;
import net.sourceforge.argparse4j.inf.Namespace;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSample;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class FinishPushNotificationExperimentCommandTest {
private CommandDependencies commandDependencies;
private PushNotificationExperiment<String> experiment;
private FinishPushNotificationExperimentCommand<String> finishPushNotificationExperimentCommand;
private static final String EXPERIMENT_NAME = "test";
private static class TestFinishPushNotificationExperimentCommand extends FinishPushNotificationExperimentCommand<String> {
public TestFinishPushNotificationExperimentCommand(final PushNotificationExperiment<String> experiment) {
super("test-finish-push-notification-experiment",
"Test start push notification experiment command",
(ignoredDependencies, ignoredConfiguration) -> experiment);
}
}
@BeforeEach
void setUp() throws JsonProcessingException {
final AccountsManager accountsManager = mock(AccountsManager.class);
final PushNotificationExperimentSamples pushNotificationExperimentSamples =
mock(PushNotificationExperimentSamples.class);
when(pushNotificationExperimentSamples.recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test")));
commandDependencies = new CommandDependencies(accountsManager,
null,
null,
null,
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
null,
null,
null);
//noinspection unchecked
experiment = mock(PushNotificationExperiment.class);
when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME);
when(experiment.getState(any(), any())).thenReturn("test");
finishPushNotificationExperimentCommand = new TestFinishPushNotificationExperimentCommand(experiment);
}
@Test
void run() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingAccount() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(null, null);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingDevice() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, null);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runAccountFetchRetry() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(commandDependencies.accountsManager(), times(3)).getByAccountIdentifierAsync(accountIdentifier);
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runStoreSampleRetry() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(new PushNotificationExperimentSample<>(true, "test", "test")));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples(), times(3))
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
@Test
void runMissingInitialSample() throws Exception {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(commandDependencies.accountsManager().getByAccountIdentifierAsync(accountIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(commandDependencies.pushNotificationExperimentSamples().getDevicesPendingFinalState(EXPERIMENT_NAME))
.thenReturn(Flux.just(Tuples.of(accountIdentifier, deviceId)));
when(commandDependencies.pushNotificationExperimentSamples().recordFinalState(any(), anyByte(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(ConditionalCheckFailedException.builder().build()));
assertDoesNotThrow(() -> finishPushNotificationExperimentCommand.run(null,
new Namespace(Map.of(FinishPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1)),
null,
commandDependencies));
verify(experiment).getState(account, device);
verify(commandDependencies.pushNotificationExperimentSamples())
.recordFinalState(eq(accountIdentifier), eq(deviceId), eq(EXPERIMENT_NAME), any());
}
}

View File

@ -0,0 +1,166 @@
package org.whispersystems.textsecuregcm.workers;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import net.sourceforge.argparse4j.inf.Namespace;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperiment;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
class StartPushNotificationExperimentCommandTest {
private PushNotificationExperimentSamples pushNotificationExperimentSamples;
private PushNotificationExperiment<String> experiment;
private StartPushNotificationExperimentCommand<String> startPushNotificationExperimentCommand;
// Taken together, these parameters will produce a device that's enrolled in the experimental group (as opposed to the
// control group) for an experiment.
private static final UUID ACCOUNT_IDENTIFIER = UUID.fromString("341fb18f-9dee-4181-bc40-e485958341d3");
private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String EXPERIMENT_NAME = "test";
private static class TestStartPushNotificationExperimentCommand extends StartPushNotificationExperimentCommand<String> {
private final CommandDependencies commandDependencies;
public TestStartPushNotificationExperimentCommand(
final PushNotificationExperimentSamples pushNotificationExperimentSamples,
final PushNotificationExperiment<String> experiment) {
super("test-start-push-notification-experiment",
"Test start push notification experiment command",
(ignoredDependencies, ignoredConfiguration) -> experiment);
this.commandDependencies = new CommandDependencies(null,
null,
null,
null,
null,
null,
null,
null,
pushNotificationExperimentSamples,
null,
null,
null,
null,
null);
}
@Override
protected Namespace getNamespace() {
return new Namespace(Map.of(StartPushNotificationExperimentCommand.MAX_CONCURRENCY_ARGUMENT, 1));
}
@Override
protected CommandDependencies getCommandDependencies() {
return commandDependencies;
}
}
@BeforeEach
void setUp() {
//noinspection unchecked
experiment = mock(PushNotificationExperiment.class);
when(experiment.getExperimentName()).thenReturn(EXPERIMENT_NAME);
when(experiment.isDeviceEligible(any(), any())).thenReturn(CompletableFuture.completedFuture(true));
when(experiment.getState(any(), any())).thenReturn("test");
when(experiment.applyExperimentTreatment(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
pushNotificationExperimentSamples = mock(PushNotificationExperimentSamples.class);
try {
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
} catch (final JsonProcessingException e) {
throw new AssertionError(e);
}
startPushNotificationExperimentCommand =
new TestStartPushNotificationExperimentCommand(pushNotificationExperimentSamples, experiment);
}
@Test
void crawlAccounts() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
}
@Test
void crawlAccountsExistingSample() throws JsonProcessingException {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment, never()).applyExperimentTreatment(account, device);
}
@Test
void crawlAccountsSampleRetry() throws JsonProcessingException {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(pushNotificationExperimentSamples.recordInitialState(any(), anyByte(), any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()))
.thenReturn(CompletableFuture.completedFuture(true));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
verify(pushNotificationExperimentSamples, times(3))
.recordInitialState(ACCOUNT_IDENTIFIER, DEVICE_ID, EXPERIMENT_NAME, true, "test");
}
@Test
void crawlAccountsExperimentException() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(account.getDevices()).thenReturn(List.of(device));
when(experiment.applyExperimentTreatment(account, device))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
assertDoesNotThrow(() -> startPushNotificationExperimentCommand.crawlAccounts(Flux.just(account)));
verify(experiment).applyExperimentTreatment(account, device);
}
}

View File

@ -110,6 +110,8 @@ dynamoDbTables:
tableName: profiles_test
pushChallenge:
tableName: push_challenge_test
pushNotificationExperimentSamples:
tableName: Example_PushNotificationExperimentSamples
redeemedReceipts:
tableName: redeemed_receipts_test
expiration: P30D # Duration of time until rows expire