Generalize push notification scheduler and add support for delayed "new messages" notifications

This commit is contained in:
Jon Chambers 2024-08-16 16:16:55 -04:00 committed by GitHub
parent 5892dc71fa
commit 659ac2c107
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 979 additions and 757 deletions

View File

@ -188,7 +188,7 @@ import org.whispersystems.textsecuregcm.metrics.TrafficSource;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
@ -649,10 +649,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs); RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration()); APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials().value()); FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials().value());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
apnSender, accountsManager, 0); apnSender, fcmSender, accountsManager, 0, 0);
PushNotificationManager pushNotificationManager = PushNotificationManager pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster); dynamicConfigurationManager, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager( ProvisioningManager provisioningManager = new ProvisioningManager(
@ -743,7 +743,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
subscriptionProcessorRetryExecutor); subscriptionProcessorRetryExecutor);
environment.lifecycle().manage(apnSender); environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(apnPushNotificationScheduler); environment.lifecycle().manage(pushNotificationScheduler);
environment.lifecycle().manage(provisioningManager); environment.lifecycle().manage(provisioningManager);
environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager); environment.lifecycle().manage(clientPresenceManager);
@ -1006,7 +1006,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager))); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager, new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager)); pushNotificationScheduler, clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler,
clientReleaseManager));
webSocketEnvironment.jersey() webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); .register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
@ -1099,7 +1100,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new KeysController(rateLimiters, keysManager, accountsManager, zkSecretParams, Clock.systemUTC()), new KeysController(rateLimiters, keysManager, accountsManager, zkSecretParams, Clock.systemUTC()),
new KeyTransparencyController(keyTransparencyServiceClient), new KeyTransparencyController(keyTransparencyServiceClient),
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender, new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, pushNotificationManager, reportMessageManager, accountsManager, messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager,
multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager, multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager,
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, Clock.systemUTC()), dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, Clock.systemUTC()),
new PaymentsController(currencyManager, paymentsCredentialsGenerator), new PaymentsController(currencyManager, paymentsCredentialsGenerator),

View File

@ -113,6 +113,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.spam.SpamChecker;
@ -157,6 +158,7 @@ public class MessageController {
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ExecutorService multiRecipientMessageExecutor; private final ExecutorService multiRecipientMessageExecutor;
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
@ -208,6 +210,8 @@ public class MessageController {
@VisibleForTesting @VisibleForTesting
static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
private static final Duration NOTIFY_FOR_REMAINING_MESSAGES_DELAY = Duration.ofMinutes(1);
public MessageController( public MessageController(
RateLimiters rateLimiters, RateLimiters rateLimiters,
CardinalityEstimator messageByteLimitEstimator, CardinalityEstimator messageByteLimitEstimator,
@ -216,6 +220,7 @@ public class MessageController {
AccountsManager accountsManager, AccountsManager accountsManager,
MessagesManager messagesManager, MessagesManager messagesManager,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
@Nonnull ExecutorService multiRecipientMessageExecutor, @Nonnull ExecutorService multiRecipientMessageExecutor,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
@ -233,6 +238,7 @@ public class MessageController {
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor); this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
@ -779,6 +785,10 @@ public class MessageController {
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages)); .record(estimateMessageListSizeBytes(messages));
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), auth.getAuthenticatedDevice(), NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}
return messages; return messages;
}) })
.timeout(Duration.ofSeconds(5)) .timeout(Duration.ofSeconds(5))

View File

@ -1,439 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.Limit;
import io.lettuce.core.Range;
import io.lettuce.core.RedisException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.Util;
public class ApnPushNotificationScheduler implements Managed {
private static final Logger logger = LoggerFactory.getLogger(ApnPushNotificationScheduler.class);
private static final String PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX = "PENDING_APN";
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
@VisibleForTesting
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
private static final Counter delivered = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_delivered"));
private static final Counter sent = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_sent"));
private static final Counter retry = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_retry"));
private static final Counter evicted = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_evicted"));
private static final Counter backgroundNotificationScheduledCounter = Metrics.counter(name(ApnPushNotificationScheduler.class, "backgroundNotification", "scheduled"));
private static final Counter backgroundNotificationSentCounter = Metrics.counter(name(ApnPushNotificationScheduler.class, "backgroundNotification", "sent"));
private final APNSender apnSender;
private final AccountsManager accountsManager;
private final FaultTolerantRedisCluster pushSchedulingCluster;
private final Clock clock;
private final ClusterLuaScript getPendingVoipDestinationsScript;
private final ClusterLuaScript insertPendingVoipDestinationScript;
private final ClusterLuaScript removePendingVoipDestinationScript;
private final ClusterLuaScript scheduleBackgroundNotificationScript;
private final Thread[] workerThreads;
@VisibleForTesting
static final Duration BACKGROUND_NOTIFICATION_PERIOD = Duration.ofMinutes(20);
private final AtomicBoolean running = new AtomicBoolean(false);
class NotificationWorker implements Runnable {
private static final int PAGE_SIZE = 128;
@Override
public void run() {
do {
try {
final long entriesProcessed = processNextSlot();
if (entriesProcessed == 0) {
Util.sleep(1000);
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
} while (running.get());
}
private long processNextSlot() {
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
return processRecurringVoipNotifications(slot) + processScheduledBackgroundNotifications(slot);
}
@VisibleForTesting
long processRecurringVoipNotifications(final int slot) {
List<String> pendingDestinations;
long entriesProcessed = 0;
do {
pendingDestinations = getPendingDestinationsForRecurringVoipNotifications(slot, PAGE_SIZE);
entriesProcessed += pendingDestinations.size();
for (final String destination : pendingDestinations) {
try {
getAccountAndDeviceFromPairString(destination).ifPresentOrElse(
accountAndDevice -> sendRecurringVoipNotification(accountAndDevice.first(), accountAndDevice.second()),
() -> removeRecurringVoipNotificationEntrySync(destination));
} catch (final IllegalArgumentException e) {
logger.warn("Failed to parse account/device pair: {}", destination, e);
}
}
} while (!pendingDestinations.isEmpty());
return entriesProcessed;
}
@VisibleForTesting
long processScheduledBackgroundNotifications(final int slot) {
final long currentTimeMillis = clock.millis();
final String queueKey = getPendingBackgroundNotificationQueueKey(slot);
final long processedBackgroundNotifications = pushSchedulingCluster.withCluster(connection -> {
List<String> destinations;
long offset = 0;
do {
destinations = connection.sync().zrangebyscore(queueKey, Range.create(0, currentTimeMillis), Limit.create(offset, PAGE_SIZE));
for (final String destination : destinations) {
try {
getAccountAndDeviceFromPairString(destination).ifPresent(accountAndDevice ->
sendBackgroundNotification(accountAndDevice.first(), accountAndDevice.second()));
} catch (final IllegalArgumentException e) {
logger.warn("Failed to parse account/device pair: {}", destination, e);
}
}
offset += destinations.size();
} while (destinations.size() == PAGE_SIZE);
return offset;
});
pushSchedulingCluster.useCluster(connection ->
connection.sync().zremrangebyscore(queueKey, Range.create(0, currentTimeMillis)));
return processedBackgroundNotifications;
}
}
public ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster,
APNSender apnSender, AccountsManager accountsManager, final int dedicatedProcessWorkerThreadCount)
throws IOException {
this(pushSchedulingCluster, apnSender, accountsManager, Clock.systemUTC(), dedicatedProcessWorkerThreadCount);
}
@VisibleForTesting
ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster,
APNSender apnSender,
AccountsManager accountsManager,
Clock clock,
int dedicatedProcessThreadCount) throws IOException {
this.apnSender = apnSender;
this.accountsManager = accountsManager;
this.pushSchedulingCluster = pushSchedulingCluster;
this.clock = clock;
this.getPendingVoipDestinationsScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/get.lua",
ScriptOutputType.MULTI);
this.insertPendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/insert.lua",
ScriptOutputType.VALUE);
this.removePendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/remove.lua",
ScriptOutputType.INTEGER);
this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
"lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
this.workerThreads = new Thread[dedicatedProcessThreadCount];
for (int i = 0; i < this.workerThreads.length; i++) {
this.workerThreads[i] = new Thread(new NotificationWorker(), "ApnFallbackManagerWorker-" + i);
}
}
/**
* Schedule a recurring VOIP notification until {@link this#cancelScheduledNotifications} is called or the device is
* removed
*
* @return A CompletionStage that completes when the recurring notification has successfully been scheduled
*/
public CompletionStage<Void> scheduleRecurringVoipNotification(Account account, Device device) {
sent.increment();
return insertRecurringVoipNotificationEntry(account, device, clock.millis() + (15 * 1000), (15 * 1000));
}
/**
* Schedule a background notification to be sent some time in the future
*
* @return A CompletionStage that completes when the notification has successfully been scheduled
*/
public CompletionStage<Void> scheduleBackgroundNotification(final Account account, final Device device) {
backgroundNotificationScheduledCounter.increment();
return scheduleBackgroundNotificationScript.executeAsync(
List.of(
getLastBackgroundNotificationTimestampKey(account, device),
getPendingBackgroundNotificationQueueKey(account, device)),
List.of(
getPairString(account, device),
String.valueOf(clock.millis()),
String.valueOf(BACKGROUND_NOTIFICATION_PERIOD.toMillis())))
.thenAccept(dropValue());
}
/**
* Cancel a scheduled recurring VOIP notification
*
* @return A CompletionStage that completes when the scheduled task has been cancelled.
*/
public CompletionStage<Void> cancelScheduledNotifications(Account account, Device device) {
return removeRecurringVoipNotificationEntry(account, device)
.thenCompose(removed -> {
if (removed) {
delivered.increment();
}
return pushSchedulingCluster.withCluster(connection ->
connection.async().zrem(
getPendingBackgroundNotificationQueueKey(account, device),
getPairString(account, device)));
})
.thenAccept(dropValue());
}
@Override
public synchronized void start() {
running.set(true);
for (final Thread workerThread : workerThreads) {
workerThread.start();
}
}
@Override
public synchronized void stop() throws InterruptedException {
running.set(false);
for (final Thread workerThread : workerThreads) {
workerThread.join();
}
}
private void sendRecurringVoipNotification(final Account account, final Device device) {
String apnId = device.getVoipApnId();
if (apnId == null) {
removeRecurringVoipNotificationEntrySync(getEndpointKey(account, device));
return;
}
long deviceLastSeen = device.getLastSeen();
if (deviceLastSeen < clock.millis() - TimeUnit.DAYS.toMillis(7)) {
evicted.increment();
removeRecurringVoipNotificationEntrySync(getEndpointKey(account, device));
return;
}
apnSender.sendNotification(new PushNotification(apnId, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true));
retry.increment();
}
@VisibleForTesting
void sendBackgroundNotification(final Account account, final Device device) {
if (StringUtils.isNotBlank(device.getApnId())) {
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
// timestamp and a timestamp older than the period are functionally equivalent.
pushSchedulingCluster.useCluster(connection -> connection.sync().set(
getLastBackgroundNotificationTimestampKey(account, device),
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)));
apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false));
backgroundNotificationSentCounter.increment();
}
}
@VisibleForTesting
static Optional<Pair<String, Byte>> getSeparated(String encoded) {
try {
if (encoded == null) return Optional.empty();
String[] parts = encoded.split(":");
if (parts.length != 2) {
logger.warn("Got strange encoded number: " + encoded);
return Optional.empty();
}
return Optional.of(new Pair<>(parts[0], Byte.parseByte(parts[1])));
} catch (NumberFormatException e) {
logger.warn("Badly formatted: " + encoded, e);
return Optional.empty();
}
}
@VisibleForTesting
static String getPairString(final Account account, final Device device) {
return account.getUuid() + ":" + device.getId();
}
@VisibleForTesting
Optional<Pair<Account, Device>> getAccountAndDeviceFromPairString(final String endpoint) {
try {
if (StringUtils.isBlank(endpoint)) {
throw new IllegalArgumentException("Endpoint must not be blank");
}
final String[] parts = endpoint.split(":");
if (parts.length != 2) {
throw new IllegalArgumentException("Could not parse endpoint string: " + endpoint);
}
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(UUID.fromString(parts[0]));
return maybeAccount.flatMap(account -> account.getDevice(Byte.parseByte(parts[1])))
.map(device -> new Pair<>(maybeAccount.get(), device));
} catch (final NumberFormatException e) {
throw new IllegalArgumentException(e);
}
}
private boolean removeRecurringVoipNotificationEntrySync(final String endpoint) {
try {
return removeRecurringVoipNotificationEntry(endpoint).toCompletableFuture().get();
} catch (ExecutionException e) {
if (e.getCause() instanceof RedisException re) {
throw re;
}
throw new RuntimeException(e);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
private CompletionStage<Boolean> removeRecurringVoipNotificationEntry(Account account, Device device) {
return removeRecurringVoipNotificationEntry(getEndpointKey(account, device));
}
private CompletionStage<Boolean> removeRecurringVoipNotificationEntry(final String endpoint) {
return removePendingVoipDestinationScript.executeAsync(
List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint),
Collections.emptyList())
.thenApply(result -> ((long) result) > 0);
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<String> getPendingDestinationsForRecurringVoipNotifications(final int slot, final int limit) {
return (List<String>) getPendingVoipDestinationsScript.execute(
List.of(getPendingRecurringVoipNotificationQueueKey(slot)),
List.of(String.valueOf(clock.millis()), String.valueOf(limit)));
}
private CompletionStage<Void> insertRecurringVoipNotificationEntry(final Account account, final Device device, final long timestamp, final long interval) {
final String endpoint = getEndpointKey(account, device);
return insertPendingVoipDestinationScript.executeAsync(
List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint),
List.of(String.valueOf(timestamp),
String.valueOf(interval),
account.getUuid().toString(),
String.valueOf(device.getId())))
.thenAccept(dropValue());
}
@VisibleForTesting
static String getEndpointKey(final Account account, final Device device) {
return "apn_device::{" + account.getUuid() + "::" + device.getId() + "}";
}
private static String getPendingRecurringVoipNotificationQueueKey(final String endpoint) {
return getPendingRecurringVoipNotificationQueueKey(SlotHash.getSlot(endpoint));
}
private static String getPendingRecurringVoipNotificationQueueKey(final int slot) {
return PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
@VisibleForTesting
static String getPendingBackgroundNotificationQueueKey(final Account account, final Device device) {
return getPendingBackgroundNotificationQueueKey(SlotHash.getSlot(getPairString(account, device)));
}
private static String getPendingBackgroundNotificationQueueKey(final int slot) {
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) {
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + getPairString(account, device) + "}";
}
@VisibleForTesting
Optional<Instant> getLastBackgroundNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device))))
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
}
@VisibleForTesting
Optional<Instant> getNextScheduledBackgroundNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().zscore(getPendingBackgroundNotificationQueueKey(account, device),
getPairString(account, device))))
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
}
private static <T> Consumer<T> dropValue() {
return ignored -> {};
}
}

View File

@ -17,7 +17,6 @@ import java.util.function.BiConsumer;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -28,7 +27,7 @@ public class PushNotificationManager {
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final APNSender apnSender; private final APNSender apnSender;
private final FcmSender fcmSender; private final FcmSender fcmSender;
private final ApnPushNotificationScheduler apnPushNotificationScheduler; private final PushNotificationScheduler pushNotificationScheduler;
private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification"); private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification");
private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification"); private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification");
@ -39,12 +38,12 @@ public class PushNotificationManager {
public PushNotificationManager(final AccountsManager accountsManager, public PushNotificationManager(final AccountsManager accountsManager,
final APNSender apnSender, final APNSender apnSender,
final FcmSender fcmSender, final FcmSender fcmSender,
final ApnPushNotificationScheduler apnPushNotificationScheduler) { final PushNotificationScheduler pushNotificationScheduler) {
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.apnSender = apnSender; this.apnSender = apnSender;
this.fcmSender = fcmSender; this.fcmSender = fcmSender;
this.apnPushNotificationScheduler = apnPushNotificationScheduler; this.pushNotificationScheduler = pushNotificationScheduler;
} }
public CompletableFuture<Optional<SendPushNotificationResult>> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException { public CompletableFuture<Optional<SendPushNotificationResult>> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
@ -82,7 +81,7 @@ public class PushNotificationManager {
} }
public void handleMessagesRetrieved(final Account account, final Device device, final String userAgent) { public void handleMessagesRetrieved(final Account account, final Device device, final String userAgent) {
apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); pushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors());
} }
@VisibleForTesting @VisibleForTesting
@ -107,8 +106,8 @@ public class PushNotificationManager {
if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) { if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) {
// APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the // APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the
// future (possibly even now!) rather than sending a notification directly // future (possibly even now!) rather than sending a notification directly
return apnPushNotificationScheduler return pushNotificationScheduler
.scheduleBackgroundNotification(pushNotification.destination(), pushNotification.destinationDevice()) .scheduleBackgroundApnsNotification(pushNotification.destination(), pushNotification.destinationDevice())
.whenComplete(logErrors()) .whenComplete(logErrors())
.thenApply(ignored -> Optional.<SendPushNotificationResult>empty()) .thenApply(ignored -> Optional.<SendPushNotificationResult>empty())
.toCompletableFuture(); .toCompletableFuture();
@ -149,7 +148,7 @@ public class PushNotificationManager {
pushNotification.destination() != null && pushNotification.destination() != null &&
pushNotification.destinationDevice() != null) { pushNotification.destinationDevice() != null) {
apnPushNotificationScheduler.scheduleRecurringVoipNotification( pushNotificationScheduler.scheduleRecurringApnsVoipNotification(
pushNotification.destination(), pushNotification.destination(),
pushNotification.destinationDevice()) pushNotification.destinationDevice())
.whenComplete(logErrors()); .whenComplete(logErrors());
@ -185,7 +184,7 @@ public class PushNotificationManager {
if (tokenExpired) { if (tokenExpired) {
if (tokenType == PushNotification.TokenType.APN || tokenType == PushNotification.TokenType.APN_VOIP) { if (tokenType == PushNotification.TokenType.APN || tokenType == PushNotification.TokenType.APN_VOIP) {
apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); pushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors());
} }
clearPushToken(account, device, tokenType); clearPushToken(account, device, tokenType);

View File

@ -0,0 +1,545 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.Range;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class PushNotificationScheduler implements Managed {
private static final Logger logger = LoggerFactory.getLogger(PushNotificationScheduler.class);
private static final String PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX = "PENDING_APN";
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
private static final String PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX = "DELAYED";
@VisibleForTesting
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
private static final Counter delivered = Metrics.counter("chat.ApnPushNotificationScheduler.voip_delivered");
private static final Counter sent = Metrics.counter("chat.ApnPushNotificationScheduler.voip_sent");
private static final Counter retry = Metrics.counter("chat.ApnPushNotificationScheduler.voip_retry");
private static final Counter evicted = Metrics.counter("chat.ApnPushNotificationScheduler.voip_evicted");
private static final Counter BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER = Metrics.counter(name(PushNotificationScheduler.class, "backgroundNotification", "scheduled"));
private static final String BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "sent");
private static final String DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationScheduled");
private static final String DELAYED_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationSent");
private static final String TOKEN_TYPE_TAG = "tokenType";
private static final String ACCEPTED_TAG = "accepted";
private final APNSender apnSender;
private final FcmSender fcmSender;
private final AccountsManager accountsManager;
private final FaultTolerantRedisCluster pushSchedulingCluster;
private final Clock clock;
private final ClusterLuaScript getPendingVoipDestinationsScript;
private final ClusterLuaScript insertPendingVoipDestinationScript;
private final ClusterLuaScript removePendingVoipDestinationScript;
private final ClusterLuaScript scheduleBackgroundApnsNotificationScript;
private final Thread[] workerThreads;
@VisibleForTesting
static final Duration BACKGROUND_NOTIFICATION_PERIOD = Duration.ofMinutes(20);
private final AtomicBoolean running = new AtomicBoolean(false);
class NotificationWorker implements Runnable {
private final int maxConcurrency;
private static final int PAGE_SIZE = 128;
NotificationWorker(final int maxConcurrency) {
this.maxConcurrency = maxConcurrency;
}
@Override
public void run() {
do {
try {
final long entriesProcessed = processNextSlot();
if (entriesProcessed == 0) {
Util.sleep(1000);
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
} while (running.get());
}
private long processNextSlot() {
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
return processRecurringApnsVoipNotifications(slot) +
processScheduledBackgroundApnsNotifications(slot) +
processScheduledDelayedNotifications(slot);
}
@VisibleForTesting
long processRecurringApnsVoipNotifications(final int slot) {
List<String> pendingDestinations;
long entriesProcessed = 0;
do {
pendingDestinations = getPendingDestinationsForRecurringApnsVoipNotifications(slot, PAGE_SIZE);
entriesProcessed += pendingDestinations.size();
Flux.fromIterable(pendingDestinations)
.flatMap(destination -> Mono.fromFuture(() -> getAccountAndDeviceFromPairString(destination))
.flatMap(maybeAccountAndDevice -> {
if (maybeAccountAndDevice.isPresent()) {
final Pair<Account, Device> accountAndDevice = maybeAccountAndDevice.get();
return Mono.fromFuture(() -> sendRecurringApnsVoipNotification(accountAndDevice.first(), accountAndDevice.second()));
} else {
final Pair<UUID, Byte> aciAndDeviceId = decodeAciAndDeviceId(destination);
return Mono.fromFuture(() -> removeRecurringApnsVoipNotificationEntry(aciAndDeviceId.first(), aciAndDeviceId.second()))
.then();
}
}), maxConcurrency)
.then()
.block();
} while (!pendingDestinations.isEmpty());
return entriesProcessed;
}
@VisibleForTesting
long processScheduledBackgroundApnsNotifications(final int slot) {
return processScheduledNotifications(getPendingBackgroundApnsNotificationQueueKey(slot),
PushNotificationScheduler.this::sendBackgroundApnsNotification);
}
@VisibleForTesting
long processScheduledDelayedNotifications(final int slot) {
return processScheduledNotifications(getDelayedNotificationQueueKey(slot),
PushNotificationScheduler.this::sendDelayedNotification);
}
private long processScheduledNotifications(final String queueKey,
final BiFunction<Account, Device, CompletableFuture<Void>> sendNotificationFunction) {
final long currentTimeMillis = clock.millis();
final AtomicLong processedNotifications = new AtomicLong(0);
pushSchedulingCluster.useCluster(
connection -> connection.reactive().zrangebyscore(queueKey, Range.create(0, currentTimeMillis))
.flatMap(encodedAciAndDeviceId -> Mono.fromFuture(
() -> getAccountAndDeviceFromPairString(encodedAciAndDeviceId)), maxConcurrency)
.flatMap(Mono::justOrEmpty)
.flatMap(accountAndDevice -> Mono.fromFuture(
() -> sendNotificationFunction.apply(accountAndDevice.first(), accountAndDevice.second()))
.then(Mono.defer(() -> connection.reactive().zrem(queueKey, encodeAciAndDeviceId(accountAndDevice.first(), accountAndDevice.second()))))
.doOnSuccess(ignored -> processedNotifications.incrementAndGet()),
maxConcurrency)
.then()
.block());
return processedNotifications.get();
}
}
public PushNotificationScheduler(final FaultTolerantRedisCluster pushSchedulingCluster,
final APNSender apnSender,
final FcmSender fcmSender,
final AccountsManager accountsManager,
final int dedicatedProcessWorkerThreadCount,
final int workerMaxConcurrency) throws IOException {
this(pushSchedulingCluster,
apnSender,
fcmSender,
accountsManager,
Clock.systemUTC(),
dedicatedProcessWorkerThreadCount,
workerMaxConcurrency);
}
@VisibleForTesting
PushNotificationScheduler(final FaultTolerantRedisCluster pushSchedulingCluster,
final APNSender apnSender,
final FcmSender fcmSender,
final AccountsManager accountsManager,
final Clock clock,
final int dedicatedProcessThreadCount,
final int workerMaxConcurrency) throws IOException {
this.apnSender = apnSender;
this.fcmSender = fcmSender;
this.accountsManager = accountsManager;
this.pushSchedulingCluster = pushSchedulingCluster;
this.clock = clock;
this.getPendingVoipDestinationsScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/get.lua",
ScriptOutputType.MULTI);
this.insertPendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/insert.lua",
ScriptOutputType.VALUE);
this.removePendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/remove.lua",
ScriptOutputType.INTEGER);
this.scheduleBackgroundApnsNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster,
"lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
this.workerThreads = new Thread[dedicatedProcessThreadCount];
for (int i = 0; i < this.workerThreads.length; i++) {
this.workerThreads[i] = new Thread(new NotificationWorker(workerMaxConcurrency), "PushNotificationScheduler-" + i);
}
}
/**
* Schedule a recurring VOIP notification until {@link this#cancelScheduledNotifications} is called or the device is
* removed
*
* @return A CompletionStage that completes when the recurring notification has successfully been scheduled
*/
public CompletionStage<Void> scheduleRecurringApnsVoipNotification(Account account, Device device) {
sent.increment();
return insertRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId(), clock.millis() + (15 * 1000), (15 * 1000));
}
/**
* Schedule a background APNs notification to be sent some time in the future.
*
* @return A CompletionStage that completes when the notification has successfully been scheduled
*
* @throws IllegalArgumentException if the given device does not have an APNs token
*/
public CompletionStage<Void> scheduleBackgroundApnsNotification(final Account account, final Device device) {
if (StringUtils.isBlank(device.getApnId())) {
throw new IllegalArgumentException("Device must have an APNs token");
}
BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER.increment();
return scheduleBackgroundApnsNotificationScript.executeAsync(
List.of(
getLastBackgroundApnsNotificationTimestampKey(account, device),
getPendingBackgroundApnsNotificationQueueKey(account, device)),
List.of(
encodeAciAndDeviceId(account, device),
String.valueOf(clock.millis()),
String.valueOf(BACKGROUND_NOTIFICATION_PERIOD.toMillis())))
.thenRun(Util.NOOP);
}
/**
* Schedules a "new message" push notification to be delivered to the given device after at least the given duration.
* If another notification had previously been scheduled, calling this method will replace the previously-scheduled
* delivery time with the given time.
*
* @param account the account to which the target device belongs
* @param device the device to which to deliver a "new message" push notification
* @param minDelay the minimum delay after which to deliver the notification
*
* @return a future that completes once the notification has been scheduled
*/
public CompletableFuture<Void> scheduleDelayedNotification(final Account account, final Device device, final Duration minDelay) {
return pushSchedulingCluster.withCluster(connection ->
connection.async().zadd(getDelayedNotificationQueueKey(account, device),
clock.instant().plus(minDelay).toEpochMilli(),
encodeAciAndDeviceId(account, device)))
.thenRun(() -> Metrics.counter(DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME,
TOKEN_TYPE_TAG, getTokenType(device))
.increment())
.toCompletableFuture();
}
/**
* Cancel a scheduled recurring VOIP notification
*
* @return A CompletionStage that completes when the scheduled task has been cancelled.
*/
public CompletionStage<Void> cancelScheduledNotifications(Account account, Device device) {
return CompletableFuture.allOf(
cancelRecurringApnsVoipNotifications(account, device),
cancelBackgroundApnsNotifications(account, device),
cancelDelayedNotifications(account, device));
}
private CompletableFuture<Void> cancelRecurringApnsVoipNotifications(final Account account, final Device device) {
return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId())
.thenCompose(removed -> {
if (removed) {
delivered.increment();
}
return pushSchedulingCluster.withCluster(connection ->
connection.async().zrem(
getPendingBackgroundApnsNotificationQueueKey(account, device),
encodeAciAndDeviceId(account, device)));
})
.thenRun(Util.NOOP)
.toCompletableFuture();
}
@VisibleForTesting
CompletableFuture<Void> cancelBackgroundApnsNotifications(final Account account, final Device device) {
return pushSchedulingCluster.withCluster(connection -> connection.async()
.zrem(getPendingBackgroundApnsNotificationQueueKey(account, device), encodeAciAndDeviceId(account, device)))
.thenRun(Util.NOOP)
.toCompletableFuture();
}
@VisibleForTesting
CompletableFuture<Void> cancelDelayedNotifications(final Account account, final Device device) {
return pushSchedulingCluster.withCluster(connection ->
connection.async().zrem(getDelayedNotificationQueueKey(account, device),
encodeAciAndDeviceId(account, device)))
.thenRun(Util.NOOP)
.toCompletableFuture();
}
@Override
public synchronized void start() {
running.set(true);
for (final Thread workerThread : workerThreads) {
workerThread.start();
}
}
@Override
public synchronized void stop() throws InterruptedException {
running.set(false);
for (final Thread workerThread : workerThreads) {
workerThread.join();
}
}
private CompletableFuture<Void> sendRecurringApnsVoipNotification(final Account account, final Device device) {
if (StringUtils.isBlank(device.getVoipApnId())) {
return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId())
.thenRun(Util.NOOP);
}
if (device.getLastSeen() < clock.millis() - TimeUnit.DAYS.toMillis(7)) {
return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId())
.thenRun(evicted::increment);
}
return apnSender.sendNotification(new PushNotification(device.getVoipApnId(), PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true))
.thenRun(retry::increment);
}
@VisibleForTesting
CompletableFuture<Void> sendBackgroundApnsNotification(final Account account, final Device device) {
if (StringUtils.isBlank(device.getApnId())) {
return CompletableFuture.completedFuture(null);
}
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
// timestamp and a timestamp older than the period are functionally equivalent.
return pushSchedulingCluster.withCluster(connection -> connection.async().set(
getLastBackgroundApnsNotificationTimestampKey(account, device),
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)))
.thenCompose(ignored -> apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)))
.thenAccept(response -> Metrics.counter(BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME,
ACCEPTED_TAG, String.valueOf(response.accepted()))
.increment())
.toCompletableFuture();
}
@VisibleForTesting
CompletableFuture<Void> sendDelayedNotification(final Account account, final Device device) {
if (StringUtils.isAllBlank(device.getApnId(), device.getGcmId())) {
return CompletableFuture.completedFuture(null);
}
final boolean isApnsDevice = StringUtils.isNotBlank(device.getApnId());
final PushNotification pushNotification = new PushNotification(
isApnsDevice ? device.getApnId() : device.getGcmId(),
isApnsDevice ? PushNotification.TokenType.APN : PushNotification.TokenType.FCM,
PushNotification.NotificationType.NOTIFICATION,
null,
account,
device,
true);
final PushNotificationSender pushNotificationSender = isApnsDevice ? apnSender : fcmSender;
return pushNotificationSender.sendNotification(pushNotification)
.thenAccept(response -> Metrics.counter(DELAYED_NOTIFICATION_SENT_COUNTER_NAME,
TOKEN_TYPE_TAG, getTokenType(device),
ACCEPTED_TAG, String.valueOf(response.accepted()))
.increment());
}
@VisibleForTesting
static String encodeAciAndDeviceId(final Account account, final Device device) {
return account.getUuid() + ":" + device.getId();
}
static Pair<UUID, Byte> decodeAciAndDeviceId(final String encoded) {
if (StringUtils.isBlank(encoded)) {
throw new IllegalArgumentException("Encoded ACI/device ID pair must not be blank");
}
final int separatorIndex = encoded.indexOf(':');
if (separatorIndex == -1) {
throw new IllegalArgumentException("String did not contain a ':' separator");
}
final UUID aci = UUID.fromString(encoded.substring(0, separatorIndex));
final byte deviceId = Byte.parseByte(encoded.substring(separatorIndex + 1));
return new Pair<>(aci, deviceId);
}
@VisibleForTesting
CompletableFuture<Optional<Pair<Account, Device>>> getAccountAndDeviceFromPairString(final String endpoint) {
final Pair<UUID, Byte> aciAndDeviceId = decodeAciAndDeviceId(endpoint);
return accountsManager.getByAccountIdentifierAsync(aciAndDeviceId.first())
.thenApply(maybeAccount -> maybeAccount
.flatMap(account -> account.getDevice(aciAndDeviceId.second()).map(device -> new Pair<>(account, device))));
}
private CompletableFuture<Boolean> removeRecurringApnsVoipNotificationEntry(final UUID aci, final byte deviceId) {
final String endpoint = getVoipEndpointKey(aci, deviceId);
return removePendingVoipDestinationScript.executeAsync(
List.of(getPendingRecurringApnsVoipNotificationQueueKey(endpoint), endpoint), Collections.emptyList())
.thenApply(result -> ((long) result) > 0);
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<String> getPendingDestinationsForRecurringApnsVoipNotifications(final int slot, final int limit) {
return (List<String>) getPendingVoipDestinationsScript.execute(
List.of(getPendingRecurringApnsVoipNotificationQueueKey(slot)),
List.of(String.valueOf(clock.millis()), String.valueOf(limit)));
}
@SuppressWarnings("SameParameterValue")
private CompletionStage<Void> insertRecurringApnsVoipNotificationEntry(final UUID aci, final byte deviceId, final long timestamp, final long interval) {
final String endpoint = getVoipEndpointKey(aci, deviceId);
return insertPendingVoipDestinationScript.executeAsync(
List.of(getPendingRecurringApnsVoipNotificationQueueKey(endpoint), endpoint),
List.of(String.valueOf(timestamp),
String.valueOf(interval),
aci.toString(),
String.valueOf(deviceId)))
.thenRun(Util.NOOP);
}
@VisibleForTesting
static String getVoipEndpointKey(final UUID aci, final byte deviceId) {
return "apn_device::{" + aci + "::" + deviceId + "}";
}
private static String getPendingRecurringApnsVoipNotificationQueueKey(final String endpoint) {
return getPendingRecurringApnsVoipNotificationQueueKey(SlotHash.getSlot(endpoint));
}
private static String getPendingRecurringApnsVoipNotificationQueueKey(final int slot) {
return PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
@VisibleForTesting
static String getPendingBackgroundApnsNotificationQueueKey(final Account account, final Device device) {
return getPendingBackgroundApnsNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
}
private static String getPendingBackgroundApnsNotificationQueueKey(final int slot) {
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private static String getLastBackgroundApnsNotificationTimestampKey(final Account account, final Device device) {
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + encodeAciAndDeviceId(account, device) + "}";
}
@VisibleForTesting
static String getDelayedNotificationQueueKey(final Account account, final Device device) {
return getDelayedNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device)));
}
private static String getDelayedNotificationQueueKey(final int slot) {
return PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
@VisibleForTesting
Optional<Instant> getLastBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().get(getLastBackgroundApnsNotificationTimestampKey(account, device))))
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
}
@VisibleForTesting
Optional<Instant> getNextScheduledBackgroundApnsNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().zscore(getPendingBackgroundApnsNotificationQueueKey(account, device),
encodeAciAndDeviceId(account, device))))
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
}
@VisibleForTesting
Optional<Instant> getNextScheduledDelayedNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().zscore(getDelayedNotificationQueueKey(account, device),
encodeAciAndDeviceId(account, device))))
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
}
private static String getTokenType(final Device device) {
if (StringUtils.isNotBlank(device.getApnId())) {
return "apns";
} else if (StringUtils.isNotBlank(device.getGcmId())) {
return "fcm";
} else {
return "unknown";
}
}
}

View File

@ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -52,6 +53,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics; private final MessageMetrics messageMetrics;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final ScheduledExecutorService scheduledExecutorService; private final ScheduledExecutorService scheduledExecutorService;
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
@ -71,6 +73,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
MessagesManager messagesManager, MessagesManager messagesManager,
MessageMetrics messageMetrics, MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
ClientPresenceManager clientPresenceManager, ClientPresenceManager clientPresenceManager,
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
@ -79,6 +82,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics; this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.scheduledExecutorService = scheduledExecutorService; this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
@ -142,6 +146,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
messagesManager, messagesManager,
messageMetrics, messageMetrics,
pushNotificationManager, pushNotificationManager,
pushNotificationScheduler,
auth, auth,
context.getClient(), context.getClient(),
scheduledExecutorService, scheduledExecutorService,

View File

@ -13,6 +13,7 @@ import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -43,8 +44,8 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -88,8 +89,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
"sendMessages"); "sendMessages");
private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class,
"sendMessageError"); "sendMessageError");
private static final String PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME =
MetricsUtil.name(WebSocketConnection.class, "pushNotificationOnClose");
private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message"; private static final String STATUS_MESSAGE_TAG = "message";
private static final String ERROR_TYPE_TAG = "errorType"; private static final String ERROR_TYPE_TAG = "errorType";
@ -109,12 +108,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private static final int DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS = 5 * 60 * 1000; private static final int DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS = 5 * 60 * 1000;
private static final Duration CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY = Duration.ofMinutes(1);
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final ReceiptSender receiptSender; private final ReceiptSender receiptSender;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics; private final MessageMetrics messageMetrics;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final AuthenticatedDevice auth; private final AuthenticatedDevice auth;
private final WebSocketClient client; private final WebSocketClient client;
@ -148,6 +150,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
MessagesManager messagesManager, MessagesManager messagesManager,
MessageMetrics messageMetrics, MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth, AuthenticatedDevice auth,
WebSocketClient client, WebSocketClient client,
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
@ -158,6 +161,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
messagesManager, messagesManager,
messageMetrics, messageMetrics,
pushNotificationManager, pushNotificationManager,
pushNotificationScheduler,
auth, auth,
client, client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
@ -171,6 +175,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
MessagesManager messagesManager, MessagesManager messagesManager,
MessageMetrics messageMetrics, MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth, AuthenticatedDevice auth,
WebSocketClient client, WebSocketClient client,
int sendFuturesTimeoutMillis, int sendFuturesTimeoutMillis,
@ -182,6 +187,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics; this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.auth = auth; this.auth = auth;
this.client = client; this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis; this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
@ -211,14 +217,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
client.close(1000, "OK"); client.close(1000, "OK");
if (storedMessageState.get() != StoredMessageState.EMPTY) { if (storedMessageState.get() != StoredMessageState.EMPTY) {
try { pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(),
pushNotificationManager.sendNewMessageNotification(auth.getAccount(), auth.getAuthenticatedDevice().getId(), true); auth.getAuthenticatedDevice(),
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
Metrics.counter(PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME,
Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())))
.increment();
} catch (NotPushRegisteredException ignored) {
}
} }
} }

View File

@ -33,7 +33,7 @@ import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controll
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -245,10 +245,10 @@ record CommandDependencies(
clock); clock);
APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value()); FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster,
apnSender, accountsManager, 0); apnSender, fcmSender, accountsManager, 0, 0);
PushNotificationManager pushNotificationManager = PushNotificationManager pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
PushNotificationExperimentSamples pushNotificationExperimentSamples = PushNotificationExperimentSamples pushNotificationExperimentSamples =
new PushNotificationExperimentSamples(dynamoDbAsyncClient, new PushNotificationExperimentSamples(dynamoDbAsyncClient,
configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(), configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(),

View File

@ -18,12 +18,14 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
public class ScheduledApnPushNotificationSenderServiceCommand extends ServerCommand<WhisperServerConfiguration> { public class ScheduledApnPushNotificationSenderServiceCommand extends ServerCommand<WhisperServerConfiguration> {
private static final String WORKER_COUNT = "workers"; private static final String WORKER_COUNT = "workers";
private static final String MAX_CONCURRENCY = "max_concurrency";
public ScheduledApnPushNotificationSenderServiceCommand() { public ScheduledApnPushNotificationSenderServiceCommand() {
super(new Application<>() { super(new Application<>() {
@ -38,11 +40,19 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm
@Override @Override
public void configure(final Subparser subparser) { public void configure(final Subparser subparser) {
super.configure(subparser); super.configure(subparser);
subparser.addArgument("--workers") subparser.addArgument("--workers")
.type(Integer.class) .type(Integer.class)
.dest(WORKER_COUNT) .dest(WORKER_COUNT)
.required(true) .required(true)
.help("The number of worker threads"); .help("The number of worker threads");
subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY)
.required(false)
.setDefault(16)
.help("The number of concurrent operations per worker thread");
} }
@Override @Override
@ -63,15 +73,16 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm
}); });
} }
final ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")) final ExecutorService pushNotificationSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d"))
.maxThreads(1).minThreads(1).build(); .maxThreads(1).minThreads(1).build();
final APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); final APNSender apnSender = new APNSender(pushNotificationSenderExecutor, configuration.getApnConfiguration());
final ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler( final FcmSender fcmSender = new FcmSender(pushNotificationSenderExecutor, configuration.getFcmConfiguration().credentials().value());
deps.pushSchedulerCluster(), apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT)); final PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(
deps.pushSchedulerCluster(), apnSender, fcmSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT), namespace.getInt(MAX_CONCURRENCY));
environment.lifecycle().manage(apnSender); environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(apnPushNotificationScheduler); environment.lifecycle().manage(pushNotificationScheduler);
MetricsUtil.registerSystemResourceMetrics(environment); MetricsUtil.registerSystemResourceMetrics(environment);

View File

@ -76,6 +76,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets; import org.junitpioneer.jupiter.cartesian.ArgumentSets;
@ -109,6 +110,7 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.spam.SpamChecker;
@ -179,6 +181,7 @@ class MessageControllerTest {
private static final CardinalityEstimator cardinalityEstimator = mock(CardinalityEstimator.class); private static final CardinalityEstimator cardinalityEstimator = mock(CardinalityEstimator.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static final PushNotificationScheduler pushNotificationScheduler = mock(PushNotificationScheduler.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
@ -200,13 +203,15 @@ class MessageControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager, messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams, SpamChecker.noop(), new MessageMetrics(), clock)) serverSecretParams, SpamChecker.noop(), new MessageMetrics(), clock))
.build(); .build();
@BeforeEach @BeforeEach
void setup() { void setup() {
reset(pushNotificationScheduler);
final List<Device> singleDeviceList = List.of( final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, true) generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, true)
); );
@ -630,8 +635,13 @@ class MessageControllerTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource @CsvSource({
void testGetMessages(boolean receiveStories) { "false, false",
"false, true",
"true, false",
"true, true"
})
void testGetMessages(final boolean receiveStories, final boolean hasMore) {
final long timestampOne = 313377; final long timestampOne = 313377;
final long timestampTwo = 313388; final long timestampTwo = 313388;
@ -651,7 +661,7 @@ class MessageControllerTest {
); );
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), anyBoolean()))
.thenReturn(Mono.just(new Pair<>(envelopes, false))); .thenReturn(Mono.just(new Pair<>(envelopes, hasMore)));
final String userAgent = "Test-UA"; final String userAgent = "Test-UA";
@ -685,13 +695,12 @@ class MessageControllerTest {
} }
verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent);
}
private static Stream<Arguments> testGetMessages() { if (hasMore) {
return Stream.of( verify(pushNotificationScheduler).scheduleDelayedNotification(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any());
Arguments.of(true), } else {
Arguments.of(false) verify(pushNotificationScheduler, never()).scheduleDelayedNotification(any(), any(), any());
); }
} }
@Test @Test

View File

@ -1,252 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
class ApnPushNotificationSchedulerTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private Account account;
private Device device;
private APNSender apnSender;
private TestClock clock;
private ApnPushNotificationScheduler apnPushNotificationScheduler;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final byte DEVICE_ID = 1;
private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32);
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@BeforeEach
void setUp() throws Exception {
device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
when(device.getApnId()).thenReturn(APN_ID);
when(device.getVoipApnId()).thenReturn(VOIP_APN_ID);
when(device.getLastSeen()).thenReturn(System.currentTimeMillis());
account = mock(Account.class);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(ACCOUNT_NUMBER);
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
final AccountsManager accountsManager = mock(AccountsManager.class);
when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account));
apnSender = mock(APNSender.class);
clock = TestClock.now();
apnPushNotificationScheduler = new ApnPushNotificationScheduler(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
apnSender, accountsManager, clock, 1);
}
@Test
void testClusterInsert() throws ExecutionException, InterruptedException {
final String endpoint = ApnPushNotificationScheduler.getEndpointKey(account, device);
final long currentTimeMillis = System.currentTimeMillis();
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000));
apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device).toCompletableFuture().get();
clock.pin(Instant.ofEpochMilli(currentTimeMillis));
final List<String> pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Optional<Pair<String, Byte>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
pendingDestinations.get(0));
assertTrue(maybeUuidAndDeviceId.isPresent());
assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first());
assertEquals(DEVICE_ID, maybeUuidAndDeviceId.get().second());
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
}
@Test
void testProcessRecurringVoipNotifications() throws ExecutionException, InterruptedException {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final long currentTimeMillis = System.currentTimeMillis();
clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000));
apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device).toCompletableFuture().get();
clock.pin(Instant.ofEpochMilli(currentTimeMillis));
final int slot = SlotHash.getSlot(ApnPushNotificationScheduler.getEndpointKey(account, device));
assertEquals(1, worker.processRecurringVoipNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(VOIP_APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(0, worker.processRecurringVoipNotifications(slot));
}
@Test
void testScheduleBackgroundNotificationWithNoRecentNotification() throws ExecutionException, InterruptedException {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(now);
assertEquals(Optional.empty(),
apnPushNotificationScheduler.getLastBackgroundNotificationTimestamp(account, device));
assertEquals(Optional.empty(),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get();
assertEquals(Optional.of(now),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
}
@Test
void testScheduleBackgroundNotificationWithRecentNotification() throws ExecutionException, InterruptedException {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final Instant recentNotificationTimestamp =
now.minus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD.dividedBy(2));
// Insert a timestamp for a recently-sent background push notification
clock.pin(Instant.ofEpochMilli(recentNotificationTimestamp.toEpochMilli()));
apnPushNotificationScheduler.sendBackgroundNotification(account, device);
clock.pin(now);
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get();
final Instant expectedScheduledTimestamp =
recentNotificationTimestamp.plus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD);
assertEquals(Optional.of(expectedScheduledTimestamp),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
}
@Test
void testProcessScheduledBackgroundNotifications() throws ExecutionException, InterruptedException {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(Instant.ofEpochMilli(now.toEpochMilli()));
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get();
final int slot =
SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device));
clock.pin(Instant.ofEpochMilli(now.minusMillis(1).toEpochMilli()));
assertEquals(0, worker.processScheduledBackgroundNotifications(slot));
clock.pin(now);
assertEquals(1, worker.processScheduledBackgroundNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(PushNotification.TokenType.APN, pushNotification.tokenType());
assertEquals(APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(PushNotification.NotificationType.NOTIFICATION, pushNotification.notificationType());
assertFalse(pushNotification.urgent());
assertEquals(0, worker.processRecurringVoipNotifications(slot));
}
@Test
void testProcessScheduledBackgroundNotificationsCancelled() throws ExecutionException, InterruptedException {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(now);
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get();
apnPushNotificationScheduler.cancelScheduledNotifications(account, device).toCompletableFuture().get();
final int slot =
SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device));
assertEquals(0, worker.processScheduledBackgroundNotifications(slot));
verify(apnSender, never()).sendNotification(any());
}
@ParameterizedTest
@CsvSource({
"1, true",
"0, false",
})
void testDedicatedProcessDynamicConfiguration(final int dedicatedThreadCount, final boolean expectActivity)
throws Exception {
final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
when(redisCluster.withCluster(any())).thenReturn(0L);
final AccountsManager accountsManager = mock(AccountsManager.class);
apnPushNotificationScheduler = new ApnPushNotificationScheduler(redisCluster, apnSender,
accountsManager, dedicatedThreadCount);
apnPushNotificationScheduler.start();
apnPushNotificationScheduler.stop();
if (expectActivity) {
verify(redisCluster, atLeastOnce()).withCluster(any());
} else {
verifyNoInteractions(redisCluster);
verifyNoInteractions(accountsManager);
verifyNoInteractions(apnSender);
}
}
}

View File

@ -33,7 +33,7 @@ class PushNotificationManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
private APNSender apnSender; private APNSender apnSender;
private FcmSender fcmSender; private FcmSender fcmSender;
private ApnPushNotificationScheduler apnPushNotificationScheduler; private PushNotificationScheduler pushNotificationScheduler;
private PushNotificationManager pushNotificationManager; private PushNotificationManager pushNotificationManager;
@ -42,12 +42,12 @@ class PushNotificationManagerTest {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
apnSender = mock(APNSender.class); apnSender = mock(APNSender.class);
fcmSender = mock(FcmSender.class); fcmSender = mock(FcmSender.class);
apnPushNotificationScheduler = mock(ApnPushNotificationScheduler.class); pushNotificationScheduler = mock(PushNotificationScheduler.class);
AccountsHelper.setupMockUpdate(accountsManager); AccountsHelper.setupMockUpdate(accountsManager);
pushNotificationManager = pushNotificationManager =
new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler);
} }
@ParameterizedTest @ParameterizedTest
@ -152,7 +152,7 @@ class PushNotificationManagerTest {
verifyNoInteractions(apnSender); verifyNoInteractions(apnSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(device, never()).setGcmId(any()); verify(device, never()).setGcmId(any());
verifyNoInteractions(apnPushNotificationScheduler); verifyNoInteractions(pushNotificationScheduler);
} }
@ParameterizedTest @ParameterizedTest
@ -171,7 +171,7 @@ class PushNotificationManagerTest {
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty()))); .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty())));
if (!urgent) { if (!urgent) {
when(apnPushNotificationScheduler.scheduleBackgroundNotification(account, device)) when(pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
} }
@ -181,10 +181,10 @@ class PushNotificationManagerTest {
if (urgent) { if (urgent) {
verify(apnSender).sendNotification(pushNotification); verify(apnSender).sendNotification(pushNotification);
verifyNoInteractions(apnPushNotificationScheduler); verifyNoInteractions(pushNotificationScheduler);
} else { } else {
verifyNoInteractions(apnSender); verifyNoInteractions(apnSender);
verify(apnPushNotificationScheduler).scheduleBackgroundNotification(account, device); verify(pushNotificationScheduler).scheduleBackgroundApnsNotification(account, device);
} }
} }
@ -210,8 +210,8 @@ class PushNotificationManagerTest {
verifyNoInteractions(fcmSender); verifyNoInteractions(fcmSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(device, never()).setGcmId(any()); verify(device, never()).setGcmId(any());
verify(apnPushNotificationScheduler).scheduleRecurringVoipNotification(account, device); verify(pushNotificationScheduler).scheduleRecurringApnsVoipNotification(account, device);
verify(apnPushNotificationScheduler, never()).scheduleBackgroundNotification(any(), any()); verify(pushNotificationScheduler, never()).scheduleBackgroundApnsNotification(any(), any());
} }
@Test @Test
@ -236,7 +236,7 @@ class PushNotificationManagerTest {
verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(device).setGcmId(null); verify(device).setGcmId(null);
verifyNoInteractions(apnSender); verifyNoInteractions(apnSender);
verifyNoInteractions(apnPushNotificationScheduler); verifyNoInteractions(pushNotificationScheduler);
} }
@Test @Test
@ -257,7 +257,7 @@ class PushNotificationManagerTest {
when(apnSender.sendNotification(pushNotification)) when(apnSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.empty()))); .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.empty())));
when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) when(pushNotificationScheduler.cancelScheduledNotifications(account, device))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
pushNotificationManager.sendNotification(pushNotification); pushNotificationManager.sendNotification(pushNotification);
@ -266,7 +266,7 @@ class PushNotificationManagerTest {
verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(device).setVoipApnId(null); verify(device).setVoipApnId(null);
verify(device, never()).setApnId(any()); verify(device, never()).setApnId(any());
verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device); verify(pushNotificationScheduler).cancelScheduledNotifications(account, device);
} }
@Test @Test
@ -290,7 +290,7 @@ class PushNotificationManagerTest {
when(apnSender.sendNotification(pushNotification)) when(apnSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.of(tokenTimestamp.minusSeconds(60))))); .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.of(tokenTimestamp.minusSeconds(60)))));
when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) when(pushNotificationScheduler.cancelScheduledNotifications(account, device))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
pushNotificationManager.sendNotification(pushNotification); pushNotificationManager.sendNotification(pushNotification);
@ -299,7 +299,7 @@ class PushNotificationManagerTest {
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any());
verify(device, never()).setVoipApnId(any()); verify(device, never()).setVoipApnId(any());
verify(device, never()).setApnId(any()); verify(device, never()).setApnId(any());
verify(apnPushNotificationScheduler, never()).cancelScheduledNotifications(account, device); verify(pushNotificationScheduler, never()).cancelScheduledNotifications(account, device);
} }
@Test @Test
@ -312,11 +312,11 @@ class PushNotificationManagerTest {
when(account.getUuid()).thenReturn(accountIdentifier); when(account.getUuid()).thenReturn(accountIdentifier);
when(device.getId()).thenReturn(Device.PRIMARY_ID); when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) when(pushNotificationScheduler.cancelScheduledNotifications(account, device))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
pushNotificationManager.handleMessagesRetrieved(account, device, userAgent); pushNotificationManager.handleMessagesRetrieved(account, device, userAgent);
verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device); verify(pushNotificationScheduler).cancelScheduledNotifications(account, device);
} }
} }

View File

@ -0,0 +1,327 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
class PushNotificationSchedulerTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private Account account;
private Device device;
private APNSender apnSender;
private FcmSender fcmSender;
private TestClock clock;
private PushNotificationScheduler pushNotificationScheduler;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final byte DEVICE_ID = 1;
private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32);
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@BeforeEach
void setUp() throws Exception {
device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
when(device.getApnId()).thenReturn(APN_ID);
when(device.getVoipApnId()).thenReturn(VOIP_APN_ID);
when(device.getLastSeen()).thenReturn(System.currentTimeMillis());
account = mock(Account.class);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(ACCOUNT_NUMBER);
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
final AccountsManager accountsManager = mock(AccountsManager.class);
when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifierAsync(ACCOUNT_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
apnSender = mock(APNSender.class);
fcmSender = mock(FcmSender.class);
clock = TestClock.now();
when(apnSender.sendNotification(any()))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty())));
when(fcmSender.sendNotification(any()))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty())));
pushNotificationScheduler = new PushNotificationScheduler(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
apnSender, fcmSender, accountsManager, clock, 1, 1);
}
@Test
void testClusterInsert() throws ExecutionException, InterruptedException {
final String endpoint = PushNotificationScheduler.getVoipEndpointKey(ACCOUNT_UUID, DEVICE_ID);
final long currentTimeMillis = System.currentTimeMillis();
assertTrue(
pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000));
pushNotificationScheduler.scheduleRecurringApnsVoipNotification(account, device).toCompletableFuture().get();
clock.pin(Instant.ofEpochMilli(currentTimeMillis));
final List<String> pendingDestinations = pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Pair<UUID, Byte> aciAndDeviceId =
PushNotificationScheduler.decodeAciAndDeviceId(pendingDestinations.getFirst());
assertEquals(ACCOUNT_UUID, aciAndDeviceId.first());
assertEquals(DEVICE_ID, aciAndDeviceId.second());
assertTrue(
pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
}
@Test
void testProcessRecurringVoipNotifications() throws ExecutionException, InterruptedException {
final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1);
final long currentTimeMillis = System.currentTimeMillis();
clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000));
pushNotificationScheduler.scheduleRecurringApnsVoipNotification(account, device).toCompletableFuture().get();
clock.pin(Instant.ofEpochMilli(currentTimeMillis));
final int slot = SlotHash.getSlot(PushNotificationScheduler.getVoipEndpointKey(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, worker.processRecurringApnsVoipNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(VOIP_APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(0, worker.processRecurringApnsVoipNotifications(slot));
}
@Test
void testScheduleBackgroundNotificationWithNoRecentApnsNotification() throws ExecutionException, InterruptedException {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(now);
assertEquals(Optional.empty(),
pushNotificationScheduler.getLastBackgroundApnsNotificationTimestamp(account, device));
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device));
pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get();
assertEquals(Optional.of(now),
pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device));
}
@Test
void testScheduleBackgroundNotificationWithRecentApnsNotification() throws ExecutionException, InterruptedException {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final Instant recentNotificationTimestamp =
now.minus(PushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD.dividedBy(2));
// Insert a timestamp for a recently-sent background push notification
clock.pin(Instant.ofEpochMilli(recentNotificationTimestamp.toEpochMilli()));
pushNotificationScheduler.sendBackgroundApnsNotification(account, device);
clock.pin(now);
pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get();
final Instant expectedScheduledTimestamp =
recentNotificationTimestamp.plus(PushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD);
assertEquals(Optional.of(expectedScheduledTimestamp),
pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device));
}
@Test
void testCancelBackgroundApnsNotifications() {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(now);
pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().join();
pushNotificationScheduler.cancelBackgroundApnsNotifications(account, device).join();
assertEquals(Optional.empty(),
pushNotificationScheduler.getLastBackgroundApnsNotificationTimestamp(account, device));
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device));
}
@Test
void testProcessScheduledBackgroundNotifications() {
final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1);
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(Instant.ofEpochMilli(now.toEpochMilli()));
pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().join();
final int slot =
SlotHash.getSlot(PushNotificationScheduler.getPendingBackgroundApnsNotificationQueueKey(account, device));
clock.pin(Instant.ofEpochMilli(now.minusMillis(1).toEpochMilli()));
assertEquals(0, worker.processScheduledBackgroundApnsNotifications(slot));
clock.pin(now);
assertEquals(1, worker.processScheduledBackgroundApnsNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(PushNotification.TokenType.APN, pushNotification.tokenType());
assertEquals(APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(PushNotification.NotificationType.NOTIFICATION, pushNotification.notificationType());
assertFalse(pushNotification.urgent());
assertEquals(0, worker.processRecurringApnsVoipNotifications(slot));
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device));
}
@Test
void testProcessScheduledBackgroundNotificationsCancelled() throws ExecutionException, InterruptedException {
final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1);
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
clock.pin(now);
pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get();
pushNotificationScheduler.cancelScheduledNotifications(account, device).toCompletableFuture().get();
final int slot =
SlotHash.getSlot(PushNotificationScheduler.getPendingBackgroundApnsNotificationQueueKey(account, device));
assertEquals(0, worker.processScheduledBackgroundApnsNotifications(slot));
verify(apnSender, never()).sendNotification(any());
}
@Test
void testScheduleDelayedNotification() {
clock.pin(Instant.now());
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device));
pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join();
assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS).plus(Duration.ofMinutes(1))),
pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device));
pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(2)).join();
assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS).plus(Duration.ofMinutes(2))),
pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device));
}
@Test
void testCancelDelayedNotification() {
pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join();
pushNotificationScheduler.cancelDelayedNotifications(account, device).join();
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device));
}
@Test
void testProcessScheduledDelayedNotifications() {
final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1);
final int slot = SlotHash.getSlot(PushNotificationScheduler.getDelayedNotificationQueueKey(account, device));
clock.pin(Instant.now());
pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join();
assertEquals(0, worker.processScheduledDelayedNotifications(slot));
clock.pin(clock.instant().plus(Duration.ofMinutes(1)));
assertEquals(1, worker.processScheduledDelayedNotifications(slot));
assertEquals(Optional.empty(),
pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device));
}
@ParameterizedTest
@CsvSource({
"1, true",
"0, false",
})
void testDedicatedProcessDynamicConfiguration(final int dedicatedThreadCount, final boolean expectActivity)
throws Exception {
final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
when(redisCluster.withCluster(any())).thenReturn(0L);
final AccountsManager accountsManager = mock(AccountsManager.class);
pushNotificationScheduler = new PushNotificationScheduler(redisCluster, apnSender, fcmSender,
accountsManager, dedicatedThreadCount, 1);
pushNotificationScheduler.start();
pushNotificationScheduler.stop();
if (expectActivity) {
verify(redisCluster, atLeastOnce()).withCluster(any());
} else {
verifyNoInteractions(redisCluster);
verifyNoInteractions(accountsManager);
verifyNoInteractions(apnSender);
}
}
}

View File

@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -127,6 +128,7 @@ class WebSocketConnectionIntegrationTest {
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device), new AuthenticatedDevice(account, device),
webSocketClient, webSocketClient,
scheduledExecutorService, scheduledExecutorService,
@ -213,6 +215,7 @@ class WebSocketConnectionIntegrationTest {
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device), new AuthenticatedDevice(account, device),
webSocketClient, webSocketClient,
scheduledExecutorService, scheduledExecutorService,
@ -280,6 +283,7 @@ class WebSocketConnectionIntegrationTest {
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device), new AuthenticatedDevice(account, device),
webSocketClient, webSocketClient,
100, // use a very short timeout, so that this test completes quickly 100, // use a very short timeout, so that this test completes quickly

View File

@ -59,6 +59,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -122,8 +123,8 @@ class WebSocketConnectionTest {
WebSocketAccountAuthenticator webSocketAuthenticator = WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(ClientPresenceManager.class), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@ -626,7 +627,7 @@ class WebSocketConnectionTest {
private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), auth, client, mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager); retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager);
} }