diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 24f554c3a..679f35a73 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -188,7 +188,7 @@ import org.whispersystems.textsecuregcm.metrics.TrafficSource; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.push.APNSender; -import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.MessageSender; @@ -649,10 +649,10 @@ public class WhisperServerService extends Application - connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT); - - return processRecurringVoipNotifications(slot) + processScheduledBackgroundNotifications(slot); - } - - @VisibleForTesting - long processRecurringVoipNotifications(final int slot) { - List pendingDestinations; - long entriesProcessed = 0; - - do { - pendingDestinations = getPendingDestinationsForRecurringVoipNotifications(slot, PAGE_SIZE); - entriesProcessed += pendingDestinations.size(); - - for (final String destination : pendingDestinations) { - try { - getAccountAndDeviceFromPairString(destination).ifPresentOrElse( - accountAndDevice -> sendRecurringVoipNotification(accountAndDevice.first(), accountAndDevice.second()), - () -> removeRecurringVoipNotificationEntrySync(destination)); - } catch (final IllegalArgumentException e) { - logger.warn("Failed to parse account/device pair: {}", destination, e); - } - } - } while (!pendingDestinations.isEmpty()); - - return entriesProcessed; - } - - @VisibleForTesting - long processScheduledBackgroundNotifications(final int slot) { - final long currentTimeMillis = clock.millis(); - final String queueKey = getPendingBackgroundNotificationQueueKey(slot); - - final long processedBackgroundNotifications = pushSchedulingCluster.withCluster(connection -> { - List destinations; - long offset = 0; - - do { - destinations = connection.sync().zrangebyscore(queueKey, Range.create(0, currentTimeMillis), Limit.create(offset, PAGE_SIZE)); - - for (final String destination : destinations) { - try { - getAccountAndDeviceFromPairString(destination).ifPresent(accountAndDevice -> - sendBackgroundNotification(accountAndDevice.first(), accountAndDevice.second())); - } catch (final IllegalArgumentException e) { - logger.warn("Failed to parse account/device pair: {}", destination, e); - } - } - - offset += destinations.size(); - } while (destinations.size() == PAGE_SIZE); - - return offset; - }); - - pushSchedulingCluster.useCluster(connection -> - connection.sync().zremrangebyscore(queueKey, Range.create(0, currentTimeMillis))); - - return processedBackgroundNotifications; - } - } - - public ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster, - APNSender apnSender, AccountsManager accountsManager, final int dedicatedProcessWorkerThreadCount) - throws IOException { - - this(pushSchedulingCluster, apnSender, accountsManager, Clock.systemUTC(), dedicatedProcessWorkerThreadCount); - } - - @VisibleForTesting - ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster, - APNSender apnSender, - AccountsManager accountsManager, - Clock clock, - int dedicatedProcessThreadCount) throws IOException { - - this.apnSender = apnSender; - this.accountsManager = accountsManager; - this.pushSchedulingCluster = pushSchedulingCluster; - this.clock = clock; - - this.getPendingVoipDestinationsScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/get.lua", - ScriptOutputType.MULTI); - this.insertPendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/insert.lua", - ScriptOutputType.VALUE); - this.removePendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/remove.lua", - ScriptOutputType.INTEGER); - - this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, - "lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE); - - this.workerThreads = new Thread[dedicatedProcessThreadCount]; - - for (int i = 0; i < this.workerThreads.length; i++) { - this.workerThreads[i] = new Thread(new NotificationWorker(), "ApnFallbackManagerWorker-" + i); - } - } - - /** - * Schedule a recurring VOIP notification until {@link this#cancelScheduledNotifications} is called or the device is - * removed - * - * @return A CompletionStage that completes when the recurring notification has successfully been scheduled - */ - public CompletionStage scheduleRecurringVoipNotification(Account account, Device device) { - sent.increment(); - return insertRecurringVoipNotificationEntry(account, device, clock.millis() + (15 * 1000), (15 * 1000)); - } - - /** - * Schedule a background notification to be sent some time in the future - * - * @return A CompletionStage that completes when the notification has successfully been scheduled - */ - public CompletionStage scheduleBackgroundNotification(final Account account, final Device device) { - backgroundNotificationScheduledCounter.increment(); - - return scheduleBackgroundNotificationScript.executeAsync( - List.of( - getLastBackgroundNotificationTimestampKey(account, device), - getPendingBackgroundNotificationQueueKey(account, device)), - List.of( - getPairString(account, device), - String.valueOf(clock.millis()), - String.valueOf(BACKGROUND_NOTIFICATION_PERIOD.toMillis()))) - .thenAccept(dropValue()); - } - - /** - * Cancel a scheduled recurring VOIP notification - * - * @return A CompletionStage that completes when the scheduled task has been cancelled. - */ - public CompletionStage cancelScheduledNotifications(Account account, Device device) { - return removeRecurringVoipNotificationEntry(account, device) - .thenCompose(removed -> { - if (removed) { - delivered.increment(); - } - return pushSchedulingCluster.withCluster(connection -> - connection.async().zrem( - getPendingBackgroundNotificationQueueKey(account, device), - getPairString(account, device))); - }) - .thenAccept(dropValue()); - } - - @Override - public synchronized void start() { - running.set(true); - - for (final Thread workerThread : workerThreads) { - workerThread.start(); - } - } - - @Override - public synchronized void stop() throws InterruptedException { - running.set(false); - - for (final Thread workerThread : workerThreads) { - workerThread.join(); - } - } - - private void sendRecurringVoipNotification(final Account account, final Device device) { - String apnId = device.getVoipApnId(); - - if (apnId == null) { - removeRecurringVoipNotificationEntrySync(getEndpointKey(account, device)); - return; - } - - long deviceLastSeen = device.getLastSeen(); - if (deviceLastSeen < clock.millis() - TimeUnit.DAYS.toMillis(7)) { - evicted.increment(); - removeRecurringVoipNotificationEntrySync(getEndpointKey(account, device)); - return; - } - - apnSender.sendNotification(new PushNotification(apnId, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true)); - retry.increment(); - } - - @VisibleForTesting - void sendBackgroundNotification(final Account account, final Device device) { - if (StringUtils.isNotBlank(device.getApnId())) { - // It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing - // timestamp and a timestamp older than the period are functionally equivalent. - pushSchedulingCluster.useCluster(connection -> connection.sync().set( - getLastBackgroundNotificationTimestampKey(account, device), - String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD))); - - apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false)); - - backgroundNotificationSentCounter.increment(); - } - } - - @VisibleForTesting - static Optional> getSeparated(String encoded) { - try { - if (encoded == null) return Optional.empty(); - - String[] parts = encoded.split(":"); - - if (parts.length != 2) { - logger.warn("Got strange encoded number: " + encoded); - return Optional.empty(); - } - - return Optional.of(new Pair<>(parts[0], Byte.parseByte(parts[1]))); - } catch (NumberFormatException e) { - logger.warn("Badly formatted: " + encoded, e); - return Optional.empty(); - } - } - - @VisibleForTesting - static String getPairString(final Account account, final Device device) { - return account.getUuid() + ":" + device.getId(); - } - - @VisibleForTesting - Optional> getAccountAndDeviceFromPairString(final String endpoint) { - try { - if (StringUtils.isBlank(endpoint)) { - throw new IllegalArgumentException("Endpoint must not be blank"); - } - - final String[] parts = endpoint.split(":"); - - if (parts.length != 2) { - throw new IllegalArgumentException("Could not parse endpoint string: " + endpoint); - } - - final Optional maybeAccount = accountsManager.getByAccountIdentifier(UUID.fromString(parts[0])); - - return maybeAccount.flatMap(account -> account.getDevice(Byte.parseByte(parts[1]))) - .map(device -> new Pair<>(maybeAccount.get(), device)); - - } catch (final NumberFormatException e) { - throw new IllegalArgumentException(e); - } - } - - private boolean removeRecurringVoipNotificationEntrySync(final String endpoint) { - try { - return removeRecurringVoipNotificationEntry(endpoint).toCompletableFuture().get(); - } catch (ExecutionException e) { - if (e.getCause() instanceof RedisException re) { - throw re; - } - throw new RuntimeException(e); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - - private CompletionStage removeRecurringVoipNotificationEntry(Account account, Device device) { - return removeRecurringVoipNotificationEntry(getEndpointKey(account, device)); - } - - private CompletionStage removeRecurringVoipNotificationEntry(final String endpoint) { - return removePendingVoipDestinationScript.executeAsync( - List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint), - Collections.emptyList()) - .thenApply(result -> ((long) result) > 0); - } - - @SuppressWarnings("unchecked") - @VisibleForTesting - List getPendingDestinationsForRecurringVoipNotifications(final int slot, final int limit) { - return (List) getPendingVoipDestinationsScript.execute( - List.of(getPendingRecurringVoipNotificationQueueKey(slot)), - List.of(String.valueOf(clock.millis()), String.valueOf(limit))); - } - - private CompletionStage insertRecurringVoipNotificationEntry(final Account account, final Device device, final long timestamp, final long interval) { - final String endpoint = getEndpointKey(account, device); - - return insertPendingVoipDestinationScript.executeAsync( - List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint), - List.of(String.valueOf(timestamp), - String.valueOf(interval), - account.getUuid().toString(), - String.valueOf(device.getId()))) - .thenAccept(dropValue()); - } - - @VisibleForTesting - static String getEndpointKey(final Account account, final Device device) { - return "apn_device::{" + account.getUuid() + "::" + device.getId() + "}"; - } - - private static String getPendingRecurringVoipNotificationQueueKey(final String endpoint) { - return getPendingRecurringVoipNotificationQueueKey(SlotHash.getSlot(endpoint)); - } - - private static String getPendingRecurringVoipNotificationQueueKey(final int slot) { - return PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}"; - } - - @VisibleForTesting - static String getPendingBackgroundNotificationQueueKey(final Account account, final Device device) { - return getPendingBackgroundNotificationQueueKey(SlotHash.getSlot(getPairString(account, device))); - } - - private static String getPendingBackgroundNotificationQueueKey(final int slot) { - return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}"; - } - - private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) { - return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + getPairString(account, device) + "}"; - } - - @VisibleForTesting - Optional getLastBackgroundNotificationTimestamp(final Account account, final Device device) { - return Optional.ofNullable( - pushSchedulingCluster.withCluster(connection -> - connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device)))) - .map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString))); - } - - @VisibleForTesting - Optional getNextScheduledBackgroundNotificationTimestamp(final Account account, final Device device) { - return Optional.ofNullable( - pushSchedulingCluster.withCluster(connection -> - connection.sync().zscore(getPendingBackgroundNotificationQueueKey(account, device), - getPairString(account, device)))) - .map(timestamp -> Instant.ofEpochMilli(timestamp.longValue())); - } - - private static Consumer dropValue() { - return ignored -> {}; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java index 2fe17ccc9..16ff36887 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java @@ -17,7 +17,6 @@ import java.util.function.BiConsumer; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -28,7 +27,7 @@ public class PushNotificationManager { private final AccountsManager accountsManager; private final APNSender apnSender; private final FcmSender fcmSender; - private final ApnPushNotificationScheduler apnPushNotificationScheduler; + private final PushNotificationScheduler pushNotificationScheduler; private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification"); private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification"); @@ -39,12 +38,12 @@ public class PushNotificationManager { public PushNotificationManager(final AccountsManager accountsManager, final APNSender apnSender, final FcmSender fcmSender, - final ApnPushNotificationScheduler apnPushNotificationScheduler) { + final PushNotificationScheduler pushNotificationScheduler) { this.accountsManager = accountsManager; this.apnSender = apnSender; this.fcmSender = fcmSender; - this.apnPushNotificationScheduler = apnPushNotificationScheduler; + this.pushNotificationScheduler = pushNotificationScheduler; } public CompletableFuture> sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException { @@ -82,7 +81,7 @@ public class PushNotificationManager { } public void handleMessagesRetrieved(final Account account, final Device device, final String userAgent) { - apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); + pushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); } @VisibleForTesting @@ -107,8 +106,8 @@ public class PushNotificationManager { if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) { // APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the // future (possibly even now!) rather than sending a notification directly - return apnPushNotificationScheduler - .scheduleBackgroundNotification(pushNotification.destination(), pushNotification.destinationDevice()) + return pushNotificationScheduler + .scheduleBackgroundApnsNotification(pushNotification.destination(), pushNotification.destinationDevice()) .whenComplete(logErrors()) .thenApply(ignored -> Optional.empty()) .toCompletableFuture(); @@ -149,7 +148,7 @@ public class PushNotificationManager { pushNotification.destination() != null && pushNotification.destinationDevice() != null) { - apnPushNotificationScheduler.scheduleRecurringVoipNotification( + pushNotificationScheduler.scheduleRecurringApnsVoipNotification( pushNotification.destination(), pushNotification.destinationDevice()) .whenComplete(logErrors()); @@ -185,7 +184,7 @@ public class PushNotificationManager { if (tokenExpired) { if (tokenType == PushNotification.TokenType.APN || tokenType == PushNotification.TokenType.APN_VOIP) { - apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); + pushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); } clearPushToken(account, device, tokenType); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationScheduler.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationScheduler.java new file mode 100644 index 000000000..b24f8008d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationScheduler.java @@ -0,0 +1,545 @@ +/* + * Copyright 2013-2020 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.lifecycle.Managed; +import io.lettuce.core.Range; +import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.SetArgs; +import io.lettuce.core.cluster.SlotHash; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.RedisClusterUtil; +import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class PushNotificationScheduler implements Managed { + + private static final Logger logger = LoggerFactory.getLogger(PushNotificationScheduler.class); + + private static final String PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX = "PENDING_APN"; + private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN"; + private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION"; + private static final String PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX = "DELAYED"; + + @VisibleForTesting + static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot"; + + private static final Counter delivered = Metrics.counter("chat.ApnPushNotificationScheduler.voip_delivered"); + private static final Counter sent = Metrics.counter("chat.ApnPushNotificationScheduler.voip_sent"); + private static final Counter retry = Metrics.counter("chat.ApnPushNotificationScheduler.voip_retry"); + private static final Counter evicted = Metrics.counter("chat.ApnPushNotificationScheduler.voip_evicted"); + + private static final Counter BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER = Metrics.counter(name(PushNotificationScheduler.class, "backgroundNotification", "scheduled")); + private static final String BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "backgroundNotification", "sent"); + + private static final String DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationScheduled"); + private static final String DELAYED_NOTIFICATION_SENT_COUNTER_NAME = name(PushNotificationScheduler.class, "delayedNotificationSent"); + private static final String TOKEN_TYPE_TAG = "tokenType"; + private static final String ACCEPTED_TAG = "accepted"; + + private final APNSender apnSender; + private final FcmSender fcmSender; + private final AccountsManager accountsManager; + private final FaultTolerantRedisCluster pushSchedulingCluster; + private final Clock clock; + + private final ClusterLuaScript getPendingVoipDestinationsScript; + private final ClusterLuaScript insertPendingVoipDestinationScript; + private final ClusterLuaScript removePendingVoipDestinationScript; + + private final ClusterLuaScript scheduleBackgroundApnsNotificationScript; + + private final Thread[] workerThreads; + + @VisibleForTesting + static final Duration BACKGROUND_NOTIFICATION_PERIOD = Duration.ofMinutes(20); + + private final AtomicBoolean running = new AtomicBoolean(false); + + class NotificationWorker implements Runnable { + + private final int maxConcurrency; + + private static final int PAGE_SIZE = 128; + + NotificationWorker(final int maxConcurrency) { + this.maxConcurrency = maxConcurrency; + } + + @Override + public void run() { + do { + try { + final long entriesProcessed = processNextSlot(); + + if (entriesProcessed == 0) { + Util.sleep(1000); + } + } catch (Exception e) { + logger.warn("Exception while operating", e); + } + } while (running.get()); + } + + private long processNextSlot() { + final int slot = (int) (pushSchedulingCluster.withCluster(connection -> + connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT); + + return processRecurringApnsVoipNotifications(slot) + + processScheduledBackgroundApnsNotifications(slot) + + processScheduledDelayedNotifications(slot); + } + + @VisibleForTesting + long processRecurringApnsVoipNotifications(final int slot) { + List pendingDestinations; + long entriesProcessed = 0; + + do { + pendingDestinations = getPendingDestinationsForRecurringApnsVoipNotifications(slot, PAGE_SIZE); + entriesProcessed += pendingDestinations.size(); + + Flux.fromIterable(pendingDestinations) + .flatMap(destination -> Mono.fromFuture(() -> getAccountAndDeviceFromPairString(destination)) + .flatMap(maybeAccountAndDevice -> { + if (maybeAccountAndDevice.isPresent()) { + final Pair accountAndDevice = maybeAccountAndDevice.get(); + return Mono.fromFuture(() -> sendRecurringApnsVoipNotification(accountAndDevice.first(), accountAndDevice.second())); + } else { + final Pair aciAndDeviceId = decodeAciAndDeviceId(destination); + return Mono.fromFuture(() -> removeRecurringApnsVoipNotificationEntry(aciAndDeviceId.first(), aciAndDeviceId.second())) + .then(); + } + }), maxConcurrency) + .then() + .block(); + } while (!pendingDestinations.isEmpty()); + + return entriesProcessed; + } + + @VisibleForTesting + long processScheduledBackgroundApnsNotifications(final int slot) { + return processScheduledNotifications(getPendingBackgroundApnsNotificationQueueKey(slot), + PushNotificationScheduler.this::sendBackgroundApnsNotification); + } + + @VisibleForTesting + long processScheduledDelayedNotifications(final int slot) { + return processScheduledNotifications(getDelayedNotificationQueueKey(slot), + PushNotificationScheduler.this::sendDelayedNotification); + } + + private long processScheduledNotifications(final String queueKey, + final BiFunction> sendNotificationFunction) { + + final long currentTimeMillis = clock.millis(); + final AtomicLong processedNotifications = new AtomicLong(0); + + pushSchedulingCluster.useCluster( + connection -> connection.reactive().zrangebyscore(queueKey, Range.create(0, currentTimeMillis)) + .flatMap(encodedAciAndDeviceId -> Mono.fromFuture( + () -> getAccountAndDeviceFromPairString(encodedAciAndDeviceId)), maxConcurrency) + .flatMap(Mono::justOrEmpty) + .flatMap(accountAndDevice -> Mono.fromFuture( + () -> sendNotificationFunction.apply(accountAndDevice.first(), accountAndDevice.second())) + .then(Mono.defer(() -> connection.reactive().zrem(queueKey, encodeAciAndDeviceId(accountAndDevice.first(), accountAndDevice.second())))) + .doOnSuccess(ignored -> processedNotifications.incrementAndGet()), + maxConcurrency) + .then() + .block()); + + return processedNotifications.get(); + } + } + + public PushNotificationScheduler(final FaultTolerantRedisCluster pushSchedulingCluster, + final APNSender apnSender, + final FcmSender fcmSender, + final AccountsManager accountsManager, + final int dedicatedProcessWorkerThreadCount, + final int workerMaxConcurrency) throws IOException { + + this(pushSchedulingCluster, + apnSender, + fcmSender, + accountsManager, + Clock.systemUTC(), + dedicatedProcessWorkerThreadCount, + workerMaxConcurrency); + } + + @VisibleForTesting + PushNotificationScheduler(final FaultTolerantRedisCluster pushSchedulingCluster, + final APNSender apnSender, + final FcmSender fcmSender, + final AccountsManager accountsManager, + final Clock clock, + final int dedicatedProcessThreadCount, + final int workerMaxConcurrency) throws IOException { + + this.apnSender = apnSender; + this.fcmSender = fcmSender; + this.accountsManager = accountsManager; + this.pushSchedulingCluster = pushSchedulingCluster; + this.clock = clock; + + this.getPendingVoipDestinationsScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/get.lua", + ScriptOutputType.MULTI); + this.insertPendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/insert.lua", + ScriptOutputType.VALUE); + this.removePendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/remove.lua", + ScriptOutputType.INTEGER); + + this.scheduleBackgroundApnsNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, + "lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE); + + this.workerThreads = new Thread[dedicatedProcessThreadCount]; + + for (int i = 0; i < this.workerThreads.length; i++) { + this.workerThreads[i] = new Thread(new NotificationWorker(workerMaxConcurrency), "PushNotificationScheduler-" + i); + } + } + + /** + * Schedule a recurring VOIP notification until {@link this#cancelScheduledNotifications} is called or the device is + * removed + * + * @return A CompletionStage that completes when the recurring notification has successfully been scheduled + */ + public CompletionStage scheduleRecurringApnsVoipNotification(Account account, Device device) { + sent.increment(); + return insertRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId(), clock.millis() + (15 * 1000), (15 * 1000)); + } + + /** + * Schedule a background APNs notification to be sent some time in the future. + * + * @return A CompletionStage that completes when the notification has successfully been scheduled + * + * @throws IllegalArgumentException if the given device does not have an APNs token + */ + public CompletionStage scheduleBackgroundApnsNotification(final Account account, final Device device) { + if (StringUtils.isBlank(device.getApnId())) { + throw new IllegalArgumentException("Device must have an APNs token"); + } + + BACKGROUND_NOTIFICATION_SCHEDULED_COUNTER.increment(); + + return scheduleBackgroundApnsNotificationScript.executeAsync( + List.of( + getLastBackgroundApnsNotificationTimestampKey(account, device), + getPendingBackgroundApnsNotificationQueueKey(account, device)), + List.of( + encodeAciAndDeviceId(account, device), + String.valueOf(clock.millis()), + String.valueOf(BACKGROUND_NOTIFICATION_PERIOD.toMillis()))) + .thenRun(Util.NOOP); + } + + /** + * Schedules a "new message" push notification to be delivered to the given device after at least the given duration. + * If another notification had previously been scheduled, calling this method will replace the previously-scheduled + * delivery time with the given time. + * + * @param account the account to which the target device belongs + * @param device the device to which to deliver a "new message" push notification + * @param minDelay the minimum delay after which to deliver the notification + * + * @return a future that completes once the notification has been scheduled + */ + public CompletableFuture scheduleDelayedNotification(final Account account, final Device device, final Duration minDelay) { + return pushSchedulingCluster.withCluster(connection -> + connection.async().zadd(getDelayedNotificationQueueKey(account, device), + clock.instant().plus(minDelay).toEpochMilli(), + encodeAciAndDeviceId(account, device))) + .thenRun(() -> Metrics.counter(DELAYED_NOTIFICATION_SCHEDULED_COUNTER_NAME, + TOKEN_TYPE_TAG, getTokenType(device)) + .increment()) + .toCompletableFuture(); + } + + /** + * Cancel a scheduled recurring VOIP notification + * + * @return A CompletionStage that completes when the scheduled task has been cancelled. + */ + public CompletionStage cancelScheduledNotifications(Account account, Device device) { + return CompletableFuture.allOf( + cancelRecurringApnsVoipNotifications(account, device), + cancelBackgroundApnsNotifications(account, device), + cancelDelayedNotifications(account, device)); + } + + private CompletableFuture cancelRecurringApnsVoipNotifications(final Account account, final Device device) { + return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId()) + .thenCompose(removed -> { + if (removed) { + delivered.increment(); + } + return pushSchedulingCluster.withCluster(connection -> + connection.async().zrem( + getPendingBackgroundApnsNotificationQueueKey(account, device), + encodeAciAndDeviceId(account, device))); + }) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } + + @VisibleForTesting + CompletableFuture cancelBackgroundApnsNotifications(final Account account, final Device device) { + return pushSchedulingCluster.withCluster(connection -> connection.async() + .zrem(getPendingBackgroundApnsNotificationQueueKey(account, device), encodeAciAndDeviceId(account, device))) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } + + @VisibleForTesting + CompletableFuture cancelDelayedNotifications(final Account account, final Device device) { + return pushSchedulingCluster.withCluster(connection -> + connection.async().zrem(getDelayedNotificationQueueKey(account, device), + encodeAciAndDeviceId(account, device))) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } + + @Override + public synchronized void start() { + running.set(true); + + for (final Thread workerThread : workerThreads) { + workerThread.start(); + } + } + + @Override + public synchronized void stop() throws InterruptedException { + running.set(false); + + for (final Thread workerThread : workerThreads) { + workerThread.join(); + } + } + + private CompletableFuture sendRecurringApnsVoipNotification(final Account account, final Device device) { + if (StringUtils.isBlank(device.getVoipApnId())) { + return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId()) + .thenRun(Util.NOOP); + } + + if (device.getLastSeen() < clock.millis() - TimeUnit.DAYS.toMillis(7)) { + return removeRecurringApnsVoipNotificationEntry(account.getIdentifier(IdentityType.ACI), device.getId()) + .thenRun(evicted::increment); + } + + return apnSender.sendNotification(new PushNotification(device.getVoipApnId(), PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true)) + .thenRun(retry::increment); + } + + @VisibleForTesting + CompletableFuture sendBackgroundApnsNotification(final Account account, final Device device) { + if (StringUtils.isBlank(device.getApnId())) { + return CompletableFuture.completedFuture(null); + } + + // It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing + // timestamp and a timestamp older than the period are functionally equivalent. + return pushSchedulingCluster.withCluster(connection -> connection.async().set( + getLastBackgroundApnsNotificationTimestampKey(account, device), + String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD))) + .thenCompose(ignored -> apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false))) + .thenAccept(response -> Metrics.counter(BACKGROUND_NOTIFICATION_SENT_COUNTER_NAME, + ACCEPTED_TAG, String.valueOf(response.accepted())) + .increment()) + .toCompletableFuture(); + } + + @VisibleForTesting + CompletableFuture sendDelayedNotification(final Account account, final Device device) { + if (StringUtils.isAllBlank(device.getApnId(), device.getGcmId())) { + return CompletableFuture.completedFuture(null); + } + + final boolean isApnsDevice = StringUtils.isNotBlank(device.getApnId()); + + final PushNotification pushNotification = new PushNotification( + isApnsDevice ? device.getApnId() : device.getGcmId(), + isApnsDevice ? PushNotification.TokenType.APN : PushNotification.TokenType.FCM, + PushNotification.NotificationType.NOTIFICATION, + null, + account, + device, + true); + + final PushNotificationSender pushNotificationSender = isApnsDevice ? apnSender : fcmSender; + + return pushNotificationSender.sendNotification(pushNotification) + .thenAccept(response -> Metrics.counter(DELAYED_NOTIFICATION_SENT_COUNTER_NAME, + TOKEN_TYPE_TAG, getTokenType(device), + ACCEPTED_TAG, String.valueOf(response.accepted())) + .increment()); + } + + @VisibleForTesting + static String encodeAciAndDeviceId(final Account account, final Device device) { + return account.getUuid() + ":" + device.getId(); + } + + static Pair decodeAciAndDeviceId(final String encoded) { + if (StringUtils.isBlank(encoded)) { + throw new IllegalArgumentException("Encoded ACI/device ID pair must not be blank"); + } + + final int separatorIndex = encoded.indexOf(':'); + + if (separatorIndex == -1) { + throw new IllegalArgumentException("String did not contain a ':' separator"); + } + + final UUID aci = UUID.fromString(encoded.substring(0, separatorIndex)); + final byte deviceId = Byte.parseByte(encoded.substring(separatorIndex + 1)); + + return new Pair<>(aci, deviceId); + } + + @VisibleForTesting + CompletableFuture>> getAccountAndDeviceFromPairString(final String endpoint) { + final Pair aciAndDeviceId = decodeAciAndDeviceId(endpoint); + + return accountsManager.getByAccountIdentifierAsync(aciAndDeviceId.first()) + .thenApply(maybeAccount -> maybeAccount + .flatMap(account -> account.getDevice(aciAndDeviceId.second()).map(device -> new Pair<>(account, device)))); + } + + private CompletableFuture removeRecurringApnsVoipNotificationEntry(final UUID aci, final byte deviceId) { + final String endpoint = getVoipEndpointKey(aci, deviceId); + + return removePendingVoipDestinationScript.executeAsync( + List.of(getPendingRecurringApnsVoipNotificationQueueKey(endpoint), endpoint), Collections.emptyList()) + .thenApply(result -> ((long) result) > 0); + } + + @SuppressWarnings("unchecked") + @VisibleForTesting + List getPendingDestinationsForRecurringApnsVoipNotifications(final int slot, final int limit) { + return (List) getPendingVoipDestinationsScript.execute( + List.of(getPendingRecurringApnsVoipNotificationQueueKey(slot)), + List.of(String.valueOf(clock.millis()), String.valueOf(limit))); + } + + @SuppressWarnings("SameParameterValue") + private CompletionStage insertRecurringApnsVoipNotificationEntry(final UUID aci, final byte deviceId, final long timestamp, final long interval) { + final String endpoint = getVoipEndpointKey(aci, deviceId); + + return insertPendingVoipDestinationScript.executeAsync( + List.of(getPendingRecurringApnsVoipNotificationQueueKey(endpoint), endpoint), + List.of(String.valueOf(timestamp), + String.valueOf(interval), + aci.toString(), + String.valueOf(deviceId))) + .thenRun(Util.NOOP); + } + + @VisibleForTesting + static String getVoipEndpointKey(final UUID aci, final byte deviceId) { + return "apn_device::{" + aci + "::" + deviceId + "}"; + } + + private static String getPendingRecurringApnsVoipNotificationQueueKey(final String endpoint) { + return getPendingRecurringApnsVoipNotificationQueueKey(SlotHash.getSlot(endpoint)); + } + + private static String getPendingRecurringApnsVoipNotificationQueueKey(final int slot) { + return PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}"; + } + + @VisibleForTesting + static String getPendingBackgroundApnsNotificationQueueKey(final Account account, final Device device) { + return getPendingBackgroundApnsNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device))); + } + + private static String getPendingBackgroundApnsNotificationQueueKey(final int slot) { + return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}"; + } + + private static String getLastBackgroundApnsNotificationTimestampKey(final Account account, final Device device) { + return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + encodeAciAndDeviceId(account, device) + "}"; + } + + @VisibleForTesting + static String getDelayedNotificationQueueKey(final Account account, final Device device) { + return getDelayedNotificationQueueKey(SlotHash.getSlot(encodeAciAndDeviceId(account, device))); + } + + private static String getDelayedNotificationQueueKey(final int slot) { + return PENDING_DELAYED_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}"; + } + + @VisibleForTesting + Optional getLastBackgroundApnsNotificationTimestamp(final Account account, final Device device) { + return Optional.ofNullable( + pushSchedulingCluster.withCluster(connection -> + connection.sync().get(getLastBackgroundApnsNotificationTimestampKey(account, device)))) + .map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString))); + } + + @VisibleForTesting + Optional getNextScheduledBackgroundApnsNotificationTimestamp(final Account account, final Device device) { + return Optional.ofNullable( + pushSchedulingCluster.withCluster(connection -> + connection.sync().zscore(getPendingBackgroundApnsNotificationQueueKey(account, device), + encodeAciAndDeviceId(account, device)))) + .map(timestamp -> Instant.ofEpochMilli(timestamp.longValue())); + } + + @VisibleForTesting + Optional getNextScheduledDelayedNotificationTimestamp(final Account account, final Device device) { + return Optional.ofNullable( + pushSchedulingCluster.withCluster(connection -> + connection.sync().zscore(getDelayedNotificationQueueKey(account, device), + encodeAciAndDeviceId(account, device)))) + .map(timestamp -> Instant.ofEpochMilli(timestamp.longValue())); + } + + private static String getTokenType(final Device device) { + if (StringUtils.isNotBlank(device.getApnId())) { + return "apns"; + } else if (StringUtils.isNotBlank(device.getGcmId())) { + return "fcm"; + } else { + return "unknown"; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index cc6d446a2..877142bff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; @@ -52,6 +53,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final MessagesManager messagesManager; private final MessageMetrics messageMetrics; private final PushNotificationManager pushNotificationManager; + private final PushNotificationScheduler pushNotificationScheduler; private final ClientPresenceManager clientPresenceManager; private final ScheduledExecutorService scheduledExecutorService; private final Scheduler messageDeliveryScheduler; @@ -71,6 +73,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { MessagesManager messagesManager, MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, + PushNotificationScheduler pushNotificationScheduler, ClientPresenceManager clientPresenceManager, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, @@ -79,6 +82,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; this.pushNotificationManager = pushNotificationManager; + this.pushNotificationScheduler = pushNotificationScheduler; this.clientPresenceManager = clientPresenceManager; this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; @@ -142,6 +146,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { messagesManager, messageMetrics, pushNotificationManager, + pushNotificationScheduler, auth, context.getClient(), scheduledExecutorService, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 7d1720384..83ed82e6d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -13,6 +13,7 @@ import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -43,8 +44,8 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; -import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -88,8 +89,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac "sendMessages"); private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, "sendMessageError"); - private static final String PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME = - MetricsUtil.name(WebSocketConnection.class, "pushNotificationOnClose"); private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; private static final String ERROR_TYPE_TAG = "errorType"; @@ -109,12 +108,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private static final int DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS = 5 * 60 * 1000; + private static final Duration CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY = Duration.ofMinutes(1); + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private final ReceiptSender receiptSender; private final MessagesManager messagesManager; private final MessageMetrics messageMetrics; private final PushNotificationManager pushNotificationManager; + private final PushNotificationScheduler pushNotificationScheduler; private final AuthenticatedDevice auth; private final WebSocketClient client; @@ -148,6 +150,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac MessagesManager messagesManager, MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, + PushNotificationScheduler pushNotificationScheduler, AuthenticatedDevice auth, WebSocketClient client, ScheduledExecutorService scheduledExecutorService, @@ -158,6 +161,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac messagesManager, messageMetrics, pushNotificationManager, + pushNotificationScheduler, auth, client, DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, @@ -171,6 +175,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac MessagesManager messagesManager, MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, + PushNotificationScheduler pushNotificationScheduler, AuthenticatedDevice auth, WebSocketClient client, int sendFuturesTimeoutMillis, @@ -182,6 +187,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; this.pushNotificationManager = pushNotificationManager; + this.pushNotificationScheduler = pushNotificationScheduler; this.auth = auth; this.client = client; this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis; @@ -211,14 +217,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac client.close(1000, "OK"); if (storedMessageState.get() != StoredMessageState.EMPTY) { - try { - pushNotificationManager.sendNewMessageNotification(auth.getAccount(), auth.getAuthenticatedDevice().getId(), true); - - Metrics.counter(PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME, - Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))) - .increment(); - } catch (NotPushRegisteredException ignored) { - } + pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), + auth.getAuthenticatedDevice(), + CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index aeb165594..e50d6e10a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -33,7 +33,7 @@ import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controll import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.push.APNSender; -import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -245,10 +245,10 @@ record CommandDependencies( clock); APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); FcmSender fcmSender = new FcmSender(fcmSenderExecutor, configuration.getFcmConfiguration().credentials().value()); - ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, - apnSender, accountsManager, 0); + PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler(pushSchedulerCluster, + apnSender, fcmSender, accountsManager, 0, 0); PushNotificationManager pushNotificationManager = - new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); + new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler); PushNotificationExperimentSamples pushNotificationExperimentSamples = new PushNotificationExperimentSamples(dynamoDbAsyncClient, configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java index c3187214f..a90cb6ab5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScheduledApnPushNotificationSenderServiceCommand.java @@ -18,12 +18,14 @@ import net.sourceforge.argparse4j.inf.Subparser; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.APNSender; -import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; +import org.whispersystems.textsecuregcm.push.FcmSender; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; public class ScheduledApnPushNotificationSenderServiceCommand extends ServerCommand { private static final String WORKER_COUNT = "workers"; + private static final String MAX_CONCURRENCY = "max_concurrency"; public ScheduledApnPushNotificationSenderServiceCommand() { super(new Application<>() { @@ -38,11 +40,19 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm @Override public void configure(final Subparser subparser) { super.configure(subparser); + subparser.addArgument("--workers") .type(Integer.class) .dest(WORKER_COUNT) .required(true) .help("The number of worker threads"); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY) + .required(false) + .setDefault(16) + .help("The number of concurrent operations per worker thread"); } @Override @@ -63,15 +73,16 @@ public class ScheduledApnPushNotificationSenderServiceCommand extends ServerComm }); } - final ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")) + final ExecutorService pushNotificationSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")) .maxThreads(1).minThreads(1).build(); - final APNSender apnSender = new APNSender(apnSenderExecutor, configuration.getApnConfiguration()); - final ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler( - deps.pushSchedulerCluster(), apnSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT)); + final APNSender apnSender = new APNSender(pushNotificationSenderExecutor, configuration.getApnConfiguration()); + final FcmSender fcmSender = new FcmSender(pushNotificationSenderExecutor, configuration.getFcmConfiguration().credentials().value()); + final PushNotificationScheduler pushNotificationScheduler = new PushNotificationScheduler( + deps.pushSchedulerCluster(), apnSender, fcmSender, deps.accountsManager(), namespace.getInt(WORKER_COUNT), namespace.getInt(MAX_CONCURRENCY)); environment.lifecycle().manage(apnSender); - environment.lifecycle().manage(apnPushNotificationScheduler); + environment.lifecycle().manage(pushNotificationScheduler); MetricsUtil.registerSystemResourceMetrics(environment); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index cd2d14e5d..3339a7fd3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -76,6 +76,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.ArgumentSets; @@ -109,6 +110,7 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.SpamChecker; @@ -179,6 +181,7 @@ class MessageControllerTest { private static final CardinalityEstimator cardinalityEstimator = mock(CardinalityEstimator.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); + private static final PushNotificationScheduler pushNotificationScheduler = mock(PushNotificationScheduler.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); @@ -200,13 +203,15 @@ class MessageControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, - messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, + messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager, serverSecretParams, SpamChecker.noop(), new MessageMetrics(), clock)) .build(); @BeforeEach void setup() { + reset(pushNotificationScheduler); + final List singleDeviceList = List.of( generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, true) ); @@ -630,8 +635,13 @@ class MessageControllerTest { } @ParameterizedTest - @MethodSource - void testGetMessages(boolean receiveStories) { + @CsvSource({ + "false, false", + "false, true", + "true, false", + "true, true" + }) + void testGetMessages(final boolean receiveStories, final boolean hasMore) { final long timestampOne = 313377; final long timestampTwo = 313388; @@ -651,7 +661,7 @@ class MessageControllerTest { ); when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), anyBoolean())) - .thenReturn(Mono.just(new Pair<>(envelopes, false))); + .thenReturn(Mono.just(new Pair<>(envelopes, hasMore))); final String userAgent = "Test-UA"; @@ -685,13 +695,12 @@ class MessageControllerTest { } verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); - } - private static Stream testGetMessages() { - return Stream.of( - Arguments.of(true), - Arguments.of(false) - ); + if (hasMore) { + verify(pushNotificationScheduler).scheduleDelayedNotification(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()); + } else { + verify(pushNotificationScheduler, never()).scheduleDelayedNotification(any(), any(), any()); + } } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationSchedulerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationSchedulerTest.java deleted file mode 100644 index a634a5c04..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationSchedulerTest.java +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.push; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -import io.lettuce.core.cluster.SlotHash; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.ExecutionException; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; -import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; -import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.Pair; -import org.whispersystems.textsecuregcm.util.TestClock; - -class ApnPushNotificationSchedulerTest { - - @RegisterExtension - static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - - private Account account; - private Device device; - - private APNSender apnSender; - private TestClock clock; - - private ApnPushNotificationScheduler apnPushNotificationScheduler; - - private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final String ACCOUNT_NUMBER = "+18005551234"; - private static final byte DEVICE_ID = 1; - private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32); - private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32); - - @BeforeEach - void setUp() throws Exception { - - device = mock(Device.class); - when(device.getId()).thenReturn(DEVICE_ID); - when(device.getApnId()).thenReturn(APN_ID); - when(device.getVoipApnId()).thenReturn(VOIP_APN_ID); - when(device.getLastSeen()).thenReturn(System.currentTimeMillis()); - - account = mock(Account.class); - when(account.getUuid()).thenReturn(ACCOUNT_UUID); - when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); - when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); - - final AccountsManager accountsManager = mock(AccountsManager.class); - when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account)); - when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account)); - - apnSender = mock(APNSender.class); - clock = TestClock.now(); - - apnPushNotificationScheduler = new ApnPushNotificationScheduler(REDIS_CLUSTER_EXTENSION.getRedisCluster(), - apnSender, accountsManager, clock, 1); - } - - @Test - void testClusterInsert() throws ExecutionException, InterruptedException { - final String endpoint = ApnPushNotificationScheduler.getEndpointKey(account, device); - final long currentTimeMillis = System.currentTimeMillis(); - - assertTrue( - apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty()); - - clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000)); - apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device).toCompletableFuture().get(); - - clock.pin(Instant.ofEpochMilli(currentTimeMillis)); - final List pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2); - assertEquals(1, pendingDestinations.size()); - - final Optional> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated( - pendingDestinations.get(0)); - - assertTrue(maybeUuidAndDeviceId.isPresent()); - assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first()); - assertEquals(DEVICE_ID, maybeUuidAndDeviceId.get().second()); - - assertTrue( - apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty()); - } - - @Test - void testProcessRecurringVoipNotifications() throws ExecutionException, InterruptedException { - final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker(); - final long currentTimeMillis = System.currentTimeMillis(); - - clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000)); - apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device).toCompletableFuture().get(); - - clock.pin(Instant.ofEpochMilli(currentTimeMillis)); - - final int slot = SlotHash.getSlot(ApnPushNotificationScheduler.getEndpointKey(account, device)); - - assertEquals(1, worker.processRecurringVoipNotifications(slot)); - - final ArgumentCaptor notificationCaptor = ArgumentCaptor.forClass(PushNotification.class); - verify(apnSender).sendNotification(notificationCaptor.capture()); - - final PushNotification pushNotification = notificationCaptor.getValue(); - - assertEquals(VOIP_APN_ID, pushNotification.deviceToken()); - assertEquals(account, pushNotification.destination()); - assertEquals(device, pushNotification.destinationDevice()); - - assertEquals(0, worker.processRecurringVoipNotifications(slot)); - } - - @Test - void testScheduleBackgroundNotificationWithNoRecentNotification() throws ExecutionException, InterruptedException { - final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); - clock.pin(now); - - assertEquals(Optional.empty(), - apnPushNotificationScheduler.getLastBackgroundNotificationTimestamp(account, device)); - - assertEquals(Optional.empty(), - apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device)); - - apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get(); - - assertEquals(Optional.of(now), - apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device)); - } - - @Test - void testScheduleBackgroundNotificationWithRecentNotification() throws ExecutionException, InterruptedException { - final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); - final Instant recentNotificationTimestamp = - now.minus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD.dividedBy(2)); - - // Insert a timestamp for a recently-sent background push notification - clock.pin(Instant.ofEpochMilli(recentNotificationTimestamp.toEpochMilli())); - apnPushNotificationScheduler.sendBackgroundNotification(account, device); - - clock.pin(now); - apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get(); - - final Instant expectedScheduledTimestamp = - recentNotificationTimestamp.plus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD); - - assertEquals(Optional.of(expectedScheduledTimestamp), - apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device)); - } - - @Test - void testProcessScheduledBackgroundNotifications() throws ExecutionException, InterruptedException { - final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker(); - - final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); - - clock.pin(Instant.ofEpochMilli(now.toEpochMilli())); - apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get(); - - final int slot = - SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device)); - - clock.pin(Instant.ofEpochMilli(now.minusMillis(1).toEpochMilli())); - assertEquals(0, worker.processScheduledBackgroundNotifications(slot)); - - clock.pin(now); - assertEquals(1, worker.processScheduledBackgroundNotifications(slot)); - - final ArgumentCaptor notificationCaptor = ArgumentCaptor.forClass(PushNotification.class); - verify(apnSender).sendNotification(notificationCaptor.capture()); - - final PushNotification pushNotification = notificationCaptor.getValue(); - - assertEquals(PushNotification.TokenType.APN, pushNotification.tokenType()); - assertEquals(APN_ID, pushNotification.deviceToken()); - assertEquals(account, pushNotification.destination()); - assertEquals(device, pushNotification.destinationDevice()); - assertEquals(PushNotification.NotificationType.NOTIFICATION, pushNotification.notificationType()); - assertFalse(pushNotification.urgent()); - - assertEquals(0, worker.processRecurringVoipNotifications(slot)); - } - - @Test - void testProcessScheduledBackgroundNotificationsCancelled() throws ExecutionException, InterruptedException { - final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker(); - - final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); - - clock.pin(now); - apnPushNotificationScheduler.scheduleBackgroundNotification(account, device).toCompletableFuture().get(); - apnPushNotificationScheduler.cancelScheduledNotifications(account, device).toCompletableFuture().get(); - - final int slot = - SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device)); - - assertEquals(0, worker.processScheduledBackgroundNotifications(slot)); - - verify(apnSender, never()).sendNotification(any()); - } - - @ParameterizedTest - @CsvSource({ - "1, true", - "0, false", - }) - void testDedicatedProcessDynamicConfiguration(final int dedicatedThreadCount, final boolean expectActivity) - throws Exception { - - final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class); - when(redisCluster.withCluster(any())).thenReturn(0L); - - final AccountsManager accountsManager = mock(AccountsManager.class); - - apnPushNotificationScheduler = new ApnPushNotificationScheduler(redisCluster, apnSender, - accountsManager, dedicatedThreadCount); - - apnPushNotificationScheduler.start(); - apnPushNotificationScheduler.stop(); - - if (expectActivity) { - verify(redisCluster, atLeastOnce()).withCluster(any()); - } else { - verifyNoInteractions(redisCluster); - verifyNoInteractions(accountsManager); - verifyNoInteractions(apnSender); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java index 9d8f28b2e..039f5cee6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java @@ -33,7 +33,7 @@ class PushNotificationManagerTest { private AccountsManager accountsManager; private APNSender apnSender; private FcmSender fcmSender; - private ApnPushNotificationScheduler apnPushNotificationScheduler; + private PushNotificationScheduler pushNotificationScheduler; private PushNotificationManager pushNotificationManager; @@ -42,12 +42,12 @@ class PushNotificationManagerTest { accountsManager = mock(AccountsManager.class); apnSender = mock(APNSender.class); fcmSender = mock(FcmSender.class); - apnPushNotificationScheduler = mock(ApnPushNotificationScheduler.class); + pushNotificationScheduler = mock(PushNotificationScheduler.class); AccountsHelper.setupMockUpdate(accountsManager); pushNotificationManager = - new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler); + new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler); } @ParameterizedTest @@ -152,7 +152,7 @@ class PushNotificationManagerTest { verifyNoInteractions(apnSender); verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(device, never()).setGcmId(any()); - verifyNoInteractions(apnPushNotificationScheduler); + verifyNoInteractions(pushNotificationScheduler); } @ParameterizedTest @@ -171,7 +171,7 @@ class PushNotificationManagerTest { .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty()))); if (!urgent) { - when(apnPushNotificationScheduler.scheduleBackgroundNotification(account, device)) + when(pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device)) .thenReturn(CompletableFuture.completedFuture(null)); } @@ -181,10 +181,10 @@ class PushNotificationManagerTest { if (urgent) { verify(apnSender).sendNotification(pushNotification); - verifyNoInteractions(apnPushNotificationScheduler); + verifyNoInteractions(pushNotificationScheduler); } else { verifyNoInteractions(apnSender); - verify(apnPushNotificationScheduler).scheduleBackgroundNotification(account, device); + verify(pushNotificationScheduler).scheduleBackgroundApnsNotification(account, device); } } @@ -210,8 +210,8 @@ class PushNotificationManagerTest { verifyNoInteractions(fcmSender); verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(device, never()).setGcmId(any()); - verify(apnPushNotificationScheduler).scheduleRecurringVoipNotification(account, device); - verify(apnPushNotificationScheduler, never()).scheduleBackgroundNotification(any(), any()); + verify(pushNotificationScheduler).scheduleRecurringApnsVoipNotification(account, device); + verify(pushNotificationScheduler, never()).scheduleBackgroundApnsNotification(any(), any()); } @Test @@ -236,7 +236,7 @@ class PushNotificationManagerTest { verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(device).setGcmId(null); verifyNoInteractions(apnSender); - verifyNoInteractions(apnPushNotificationScheduler); + verifyNoInteractions(pushNotificationScheduler); } @Test @@ -257,7 +257,7 @@ class PushNotificationManagerTest { when(apnSender.sendNotification(pushNotification)) .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.empty()))); - when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) + when(pushNotificationScheduler.cancelScheduledNotifications(account, device)) .thenReturn(CompletableFuture.completedFuture(null)); pushNotificationManager.sendNotification(pushNotification); @@ -266,7 +266,7 @@ class PushNotificationManagerTest { verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(device).setVoipApnId(null); verify(device, never()).setApnId(any()); - verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device); + verify(pushNotificationScheduler).cancelScheduledNotifications(account, device); } @Test @@ -290,7 +290,7 @@ class PushNotificationManagerTest { when(apnSender.sendNotification(pushNotification)) .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.of(tokenTimestamp.minusSeconds(60))))); - when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) + when(pushNotificationScheduler.cancelScheduledNotifications(account, device)) .thenReturn(CompletableFuture.completedFuture(null)); pushNotificationManager.sendNotification(pushNotification); @@ -299,7 +299,7 @@ class PushNotificationManagerTest { verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); verify(device, never()).setVoipApnId(any()); verify(device, never()).setApnId(any()); - verify(apnPushNotificationScheduler, never()).cancelScheduledNotifications(account, device); + verify(pushNotificationScheduler, never()).cancelScheduledNotifications(account, device); } @Test @@ -312,11 +312,11 @@ class PushNotificationManagerTest { when(account.getUuid()).thenReturn(accountIdentifier); when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) + when(pushNotificationScheduler.cancelScheduledNotifications(account, device)) .thenReturn(CompletableFuture.completedFuture(null)); pushNotificationManager.handleMessagesRetrieved(account, device, userAgent); - verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device); + verify(pushNotificationScheduler).cancelScheduledNotifications(account, device); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationSchedulerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationSchedulerTest.java new file mode 100644 index 000000000..ead73819e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationSchedulerTest.java @@ -0,0 +1,327 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import io.lettuce.core.cluster.SlotHash; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.TestClock; + +class PushNotificationSchedulerTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + private Account account; + private Device device; + + private APNSender apnSender; + private FcmSender fcmSender; + private TestClock clock; + + private PushNotificationScheduler pushNotificationScheduler; + + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); + private static final String ACCOUNT_NUMBER = "+18005551234"; + private static final byte DEVICE_ID = 1; + private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32); + private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32); + + @BeforeEach + void setUp() throws Exception { + + device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + when(device.getApnId()).thenReturn(APN_ID); + when(device.getVoipApnId()).thenReturn(VOIP_APN_ID); + when(device.getLastSeen()).thenReturn(System.currentTimeMillis()); + + account = mock(Account.class); + when(account.getUuid()).thenReturn(ACCOUNT_UUID); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_UUID); + when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); + when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); + + final AccountsManager accountsManager = mock(AccountsManager.class); + when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account)); + when(accountsManager.getByAccountIdentifierAsync(ACCOUNT_UUID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + apnSender = mock(APNSender.class); + fcmSender = mock(FcmSender.class); + clock = TestClock.now(); + + when(apnSender.sendNotification(any())) + .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty()))); + + when(fcmSender.sendNotification(any())) + .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, Optional.empty(), false, Optional.empty()))); + + pushNotificationScheduler = new PushNotificationScheduler(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + apnSender, fcmSender, accountsManager, clock, 1, 1); + } + + @Test + void testClusterInsert() throws ExecutionException, InterruptedException { + final String endpoint = PushNotificationScheduler.getVoipEndpointKey(ACCOUNT_UUID, DEVICE_ID); + final long currentTimeMillis = System.currentTimeMillis(); + + assertTrue( + pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty()); + + clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000)); + pushNotificationScheduler.scheduleRecurringApnsVoipNotification(account, device).toCompletableFuture().get(); + + clock.pin(Instant.ofEpochMilli(currentTimeMillis)); + final List pendingDestinations = pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 2); + assertEquals(1, pendingDestinations.size()); + + final Pair aciAndDeviceId = + PushNotificationScheduler.decodeAciAndDeviceId(pendingDestinations.getFirst()); + + assertEquals(ACCOUNT_UUID, aciAndDeviceId.first()); + assertEquals(DEVICE_ID, aciAndDeviceId.second()); + + assertTrue( + pushNotificationScheduler.getPendingDestinationsForRecurringApnsVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty()); + } + + @Test + void testProcessRecurringVoipNotifications() throws ExecutionException, InterruptedException { + final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1); + final long currentTimeMillis = System.currentTimeMillis(); + + clock.pin(Instant.ofEpochMilli(currentTimeMillis - 30_000)); + pushNotificationScheduler.scheduleRecurringApnsVoipNotification(account, device).toCompletableFuture().get(); + + clock.pin(Instant.ofEpochMilli(currentTimeMillis)); + + final int slot = SlotHash.getSlot(PushNotificationScheduler.getVoipEndpointKey(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(1, worker.processRecurringApnsVoipNotifications(slot)); + + final ArgumentCaptor notificationCaptor = ArgumentCaptor.forClass(PushNotification.class); + verify(apnSender).sendNotification(notificationCaptor.capture()); + + final PushNotification pushNotification = notificationCaptor.getValue(); + + assertEquals(VOIP_APN_ID, pushNotification.deviceToken()); + assertEquals(account, pushNotification.destination()); + assertEquals(device, pushNotification.destinationDevice()); + + assertEquals(0, worker.processRecurringApnsVoipNotifications(slot)); + } + + @Test + void testScheduleBackgroundNotificationWithNoRecentApnsNotification() throws ExecutionException, InterruptedException { + final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + clock.pin(now); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getLastBackgroundApnsNotificationTimestamp(account, device)); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device)); + + pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get(); + + assertEquals(Optional.of(now), + pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device)); + } + + @Test + void testScheduleBackgroundNotificationWithRecentApnsNotification() throws ExecutionException, InterruptedException { + final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + final Instant recentNotificationTimestamp = + now.minus(PushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD.dividedBy(2)); + + // Insert a timestamp for a recently-sent background push notification + clock.pin(Instant.ofEpochMilli(recentNotificationTimestamp.toEpochMilli())); + pushNotificationScheduler.sendBackgroundApnsNotification(account, device); + + clock.pin(now); + pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get(); + + final Instant expectedScheduledTimestamp = + recentNotificationTimestamp.plus(PushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD); + + assertEquals(Optional.of(expectedScheduledTimestamp), + pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device)); + } + + @Test + void testCancelBackgroundApnsNotifications() { + final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + clock.pin(now); + + pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().join(); + pushNotificationScheduler.cancelBackgroundApnsNotifications(account, device).join(); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getLastBackgroundApnsNotificationTimestamp(account, device)); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device)); + } + + @Test + void testProcessScheduledBackgroundNotifications() { + final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1); + + final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + + clock.pin(Instant.ofEpochMilli(now.toEpochMilli())); + pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().join(); + + final int slot = + SlotHash.getSlot(PushNotificationScheduler.getPendingBackgroundApnsNotificationQueueKey(account, device)); + + clock.pin(Instant.ofEpochMilli(now.minusMillis(1).toEpochMilli())); + assertEquals(0, worker.processScheduledBackgroundApnsNotifications(slot)); + + clock.pin(now); + assertEquals(1, worker.processScheduledBackgroundApnsNotifications(slot)); + + final ArgumentCaptor notificationCaptor = ArgumentCaptor.forClass(PushNotification.class); + verify(apnSender).sendNotification(notificationCaptor.capture()); + + final PushNotification pushNotification = notificationCaptor.getValue(); + + assertEquals(PushNotification.TokenType.APN, pushNotification.tokenType()); + assertEquals(APN_ID, pushNotification.deviceToken()); + assertEquals(account, pushNotification.destination()); + assertEquals(device, pushNotification.destinationDevice()); + assertEquals(PushNotification.NotificationType.NOTIFICATION, pushNotification.notificationType()); + assertFalse(pushNotification.urgent()); + + assertEquals(0, worker.processRecurringApnsVoipNotifications(slot)); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledBackgroundApnsNotificationTimestamp(account, device)); + } + + @Test + void testProcessScheduledBackgroundNotificationsCancelled() throws ExecutionException, InterruptedException { + final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1); + + final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + + clock.pin(now); + pushNotificationScheduler.scheduleBackgroundApnsNotification(account, device).toCompletableFuture().get(); + pushNotificationScheduler.cancelScheduledNotifications(account, device).toCompletableFuture().get(); + + final int slot = + SlotHash.getSlot(PushNotificationScheduler.getPendingBackgroundApnsNotificationQueueKey(account, device)); + + assertEquals(0, worker.processScheduledBackgroundApnsNotifications(slot)); + + verify(apnSender, never()).sendNotification(any()); + } + + @Test + void testScheduleDelayedNotification() { + clock.pin(Instant.now()); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device)); + + pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join(); + + assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS).plus(Duration.ofMinutes(1))), + pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device)); + + pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(2)).join(); + + assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS).plus(Duration.ofMinutes(2))), + pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device)); + } + + @Test + void testCancelDelayedNotification() { + pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join(); + pushNotificationScheduler.cancelDelayedNotifications(account, device).join(); + + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device)); + } + + @Test + void testProcessScheduledDelayedNotifications() { + final PushNotificationScheduler.NotificationWorker worker = pushNotificationScheduler.new NotificationWorker(1); + final int slot = SlotHash.getSlot(PushNotificationScheduler.getDelayedNotificationQueueKey(account, device)); + + clock.pin(Instant.now()); + + pushNotificationScheduler.scheduleDelayedNotification(account, device, Duration.ofMinutes(1)).join(); + + assertEquals(0, worker.processScheduledDelayedNotifications(slot)); + + clock.pin(clock.instant().plus(Duration.ofMinutes(1))); + + assertEquals(1, worker.processScheduledDelayedNotifications(slot)); + assertEquals(Optional.empty(), + pushNotificationScheduler.getNextScheduledDelayedNotificationTimestamp(account, device)); + } + + @ParameterizedTest + @CsvSource({ + "1, true", + "0, false", + }) + void testDedicatedProcessDynamicConfiguration(final int dedicatedThreadCount, final boolean expectActivity) + throws Exception { + + final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class); + when(redisCluster.withCluster(any())).thenReturn(0L); + + final AccountsManager accountsManager = mock(AccountsManager.class); + + pushNotificationScheduler = new PushNotificationScheduler(redisCluster, apnSender, fcmSender, + accountsManager, dedicatedThreadCount, 1); + + pushNotificationScheduler.start(); + pushNotificationScheduler.stop(); + + if (expectActivity) { + verify(redisCluster, atLeastOnce()).withCluster(any()); + } else { + verifyNoInteractions(redisCluster); + verifyNoInteractions(accountsManager); + verifyNoInteractions(apnSender); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index b93ed78a0..53622c21e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; @@ -127,6 +128,7 @@ class WebSocketConnectionIntegrationTest { new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), mock(PushNotificationManager.class), + mock(PushNotificationScheduler.class), new AuthenticatedDevice(account, device), webSocketClient, scheduledExecutorService, @@ -213,6 +215,7 @@ class WebSocketConnectionIntegrationTest { new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), mock(PushNotificationManager.class), + mock(PushNotificationScheduler.class), new AuthenticatedDevice(account, device), webSocketClient, scheduledExecutorService, @@ -280,6 +283,7 @@ class WebSocketConnectionIntegrationTest { new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), mock(PushNotificationManager.class), + mock(PushNotificationScheduler.class), new AuthenticatedDevice(account, device), webSocketClient, 100, // use a very short timeout, so that this test completes quickly diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index f3f32a667..0a114c051 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -59,6 +59,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; +import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -122,8 +123,8 @@ class WebSocketConnectionTest { WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, - new MessageMetrics(), mock(PushNotificationManager.class), mock(ClientPresenceManager.class), - retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); + new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), + mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) @@ -626,7 +627,7 @@ class WebSocketConnectionTest { private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), - mock(PushNotificationManager.class), auth, client, + mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager); }