diff --git a/pom.xml b/pom.xml
index bfba1137d..e16a6f7fa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -56,7 +56,7 @@
3.0.2
1.9.24
1.5.1
- 6.3.2.RELEASE
+ 6.4.1.RELEASE
8.13.40
7.3
2.23.1
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 02b91fc52..b716bd967 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -196,6 +196,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -569,6 +570,8 @@ public class WhisperServerService extends Application(AuthenticatedDevice.class));
- environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
+ environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
+ pubSubClientEventManager));
environment.jersey().register(new TimestampResponseFilter());
///
@@ -1006,10 +1012,11 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
- provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
+ provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
+ pubSubClientEventManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java
index 01b3b1567..7b0605cda 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java
@@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -55,6 +56,7 @@ public class RegistrationLockVerificationManager {
private final AccountsManager accounts;
private final ClientPresenceManager clientPresenceManager;
+ private final PubSubClientEventManager pubSubClientEventManager;
private final ExternalServiceCredentialsGenerator svr2CredentialGenerator;
private final ExternalServiceCredentialsGenerator svr3CredentialGenerator;
private final RateLimiters rateLimiters;
@@ -62,7 +64,9 @@ public class RegistrationLockVerificationManager {
private final PushNotificationManager pushNotificationManager;
public RegistrationLockVerificationManager(
- final AccountsManager accounts, final ClientPresenceManager clientPresenceManager,
+ final AccountsManager accounts,
+ final ClientPresenceManager clientPresenceManager,
+ final PubSubClientEventManager pubSubClientEventManager,
final ExternalServiceCredentialsGenerator svr2CredentialGenerator,
final ExternalServiceCredentialsGenerator svr3CredentialGenerator,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
@@ -70,6 +74,7 @@ public class RegistrationLockVerificationManager {
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.clientPresenceManager = clientPresenceManager;
+ this.pubSubClientEventManager = pubSubClientEventManager;
this.svr2CredentialGenerator = svr2CredentialGenerator;
this.svr3CredentialGenerator = svr3CredentialGenerator;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
@@ -161,6 +166,7 @@ public class RegistrationLockVerificationManager {
final List deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList();
clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds);
+ pubSubClientEventManager.requestDisconnection(updatedAccount.getUuid(), deviceIds);
try {
// Send a push notification that prompts the client to attempt login and fail due to locked credentials
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java
index 2f54d78ba..a645a790c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java
@@ -10,6 +10,7 @@ import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
/**
@@ -20,9 +21,11 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener;
public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager,
- final ClientPresenceManager clientPresenceManager) {
+ final ClientPresenceManager clientPresenceManager,
+ final PubSubClientEventManager pubSubClientEventManager) {
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
+ pubSubClientEventManager,
new LinkedDeviceRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java
index 9fbb84fab..65f8e4dcb 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java
@@ -10,6 +10,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Context;
@@ -19,10 +20,12 @@ import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
public class WebsocketRefreshRequestEventListener implements RequestEventListener {
private final ClientPresenceManager clientPresenceManager;
+ private final PubSubClientEventManager pubSubClientEventManager;
private final WebsocketRefreshRequirementProvider[] providers;
private static final Counter DISPLACED_ACCOUNTS = Metrics.counter(
@@ -35,9 +38,11 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
public WebsocketRefreshRequestEventListener(
final ClientPresenceManager clientPresenceManager,
+ final PubSubClientEventManager pubSubClientEventManager,
final WebsocketRefreshRequirementProvider... providers) {
this.clientPresenceManager = clientPresenceManager;
+ this.pubSubClientEventManager = pubSubClientEventManager;
this.providers = providers;
}
@@ -60,6 +65,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
try {
displacedDevices.incrementAndGet();
clientPresenceManager.disconnectPresence(pair.first(), pair.second());
+ pubSubClientEventManager.requestDisconnection(pair.first(), List.of(pair.second()));
} catch (final Exception e) {
logger.error("Could not displace device presence", e);
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java
new file mode 100644
index 000000000..1697fb545
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.push;
+
+/**
+ * A client event listener handles events related to a client's message-retrieval presence. Handler methods are run on
+ * dedicated threads and may safely perform blocking operations.
+ */
+public interface ClientEventListener {
+
+ /**
+ * Indicates that a new message is available in the connected client's message queue.
+ */
+ void handleNewMessageAvailable();
+
+ /**
+ * Indicates that the client's presence has been displaced and the listener should close the client's underlying
+ * network connection.
+ *
+ * @param connectedElsewhere if {@code true}, indicates that the client's presence has been displaced by another
+ * connection from the same client
+ */
+ void handleConnectionDisplaced(boolean connectedElsewhere);
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
index 80083a4dc..54cbe6be4 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
@@ -14,6 +14,7 @@ import io.lettuce.core.RedisFuture;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
@@ -277,7 +278,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter Metrics.counter(SEND_COUNTER_NAME,
+ CHANNEL_TAG_NAME, channel,
+ EPHEMERAL_TAG_NAME, String.valueOf(online),
+ CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
+ PUB_SUB_CLIENT_ONLINE_TAG_NAME, String.valueOf(Objects.requireNonNullElse(present, false)),
+ URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
+ STORY_TAG_NAME, String.valueOf(message.getStory()),
+ SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
+ .increment());
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java
new file mode 100644
index 000000000..e3d84751d
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java
@@ -0,0 +1,407 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.push;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.protobuf.InvalidProtocolBufferException;
+import io.dropwizard.lifecycle.Managed;
+import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
+import io.micrometer.core.instrument.Counter;
+import io.micrometer.core.instrument.Metrics;
+import io.micrometer.core.instrument.Tags;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
+import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+import org.whispersystems.textsecuregcm.util.Util;
+import javax.annotation.Nullable;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * The pub/sub-based client presence manager uses the Redis 7 sharded pub/sub system to notify connected clients that
+ * new messages are available for retrieval and report to senders whether a client was present to receive a message when
+ * sent. This system makes a best effort to ensure that a given client has only a single open connection across the
+ * fleet of servers, but cannot guarantee at-most-one behavior.
+ */
+public class PubSubClientEventManager extends RedisClusterPubSubAdapter implements Managed {
+
+ private final FaultTolerantRedisClusterClient clusterClient;
+ private final Executor listenerEventExecutor;
+
+ private final ExperimentEnrollmentManager experimentEnrollmentManager;
+ static final String EXPERIMENT_NAME = "pubSubPresenceManager";
+
+ @Nullable
+ private FaultTolerantPubSubClusterConnection pubSubConnection;
+
+ private final Map listenersByAccountAndDeviceIdentifier;
+
+ private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder()
+ .setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
+ .build()
+ .toByteArray();
+
+ private static final byte[] DISCONNECT_REQUESTED_EVENT_BYTES = ClientEvent.newBuilder()
+ .setDisconnectRequested(DisconnectRequested.getDefaultInstance())
+ .build()
+ .toByteArray();
+
+ private static final Counter PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER =
+ Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "publishClientConnectionEventError"));
+
+ private static final Counter UNSUBSCRIBE_ERROR_COUNTER =
+ Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "unsubscribeError"));
+
+ private static final Counter MESSAGE_WITHOUT_LISTENER_COUNTER =
+ Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "messageWithoutListener"));
+
+ private static final String LISTENER_GAUGE_NAME =
+ MetricsUtil.name(PubSubClientEventManager.class, "listeners");
+
+ private static final Logger logger = LoggerFactory.getLogger(PubSubClientEventManager.class);
+
+ private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) {
+ }
+
+ private record ConnectionIdAndListener(UUID connectionIdentifier, ClientEventListener listener) {
+ }
+
+ public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient,
+ final Executor listenerEventExecutor,
+ final ExperimentEnrollmentManager experimentEnrollmentManager) {
+
+ this.clusterClient = clusterClient;
+ this.listenerEventExecutor = listenerEventExecutor;
+ this.experimentEnrollmentManager = experimentEnrollmentManager;
+
+ this.listenersByAccountAndDeviceIdentifier =
+ Metrics.gaugeMapSize(LISTENER_GAUGE_NAME, Tags.empty(), new ConcurrentHashMap<>());
+ }
+
+ @Override
+ public synchronized void start() {
+ this.pubSubConnection = clusterClient.createBinaryPubSubConnection();
+ this.pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
+
+ pubSubConnection.subscribeToClusterTopologyChangedEvents(this::resubscribe);
+ }
+
+ @Override
+ public synchronized void stop() {
+ if (pubSubConnection != null) {
+ pubSubConnection.usePubSubConnection(connection -> {
+ connection.removeListener(this);
+ connection.close();
+ });
+ }
+
+ pubSubConnection = null;
+ }
+
+ /**
+ * Marks the given device as "present" and registers a listener for new messages and conflicting connections. If the
+ * given device already has a presence registered with this presence manager instance, that presence is displaced
+ * immediately and the listener's {@link ClientEventListener#handleConnectionDisplaced(boolean)} method is called.
+ *
+ * @param accountIdentifier the account identifier for the newly-connected device
+ * @param deviceId the ID of the newly-connected device within the given account
+ * @param listener the listener to notify when new messages or conflicting connections arrive for the newly-conencted
+ * device
+ *
+ * @return a future that yields a connection identifier when the new device's presence has been registered; the future
+ * may fail if a pub/sub subscription could not be established, in which case callers should close the client's
+ * connection to the server
+ */
+ public CompletionStage handleClientConnected(final UUID accountIdentifier, final byte deviceId, final ClientEventListener listener) {
+ if (pubSubConnection == null) {
+ throw new IllegalStateException("Presence manager not started");
+ }
+
+ if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
+ return CompletableFuture.completedFuture(UUID.randomUUID());
+ }
+
+ final UUID connectionId = UUID.randomUUID();
+ final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
+ final AtomicReference displacedListener = new AtomicReference<>();
+ final AtomicReference> subscribeFuture = new AtomicReference<>();
+
+ // Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
+ // particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
+ //
+ // > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
+ // > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
+ // > computation is in progress, so the computation should be short and simple.
+ //
+ // This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
+ // as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
+ // operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
+ listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
+ (key, existingIdAndListener) -> {
+ subscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
+ connection.async().ssubscribe(clientPresenceKey)));
+
+ if (existingIdAndListener != null) {
+ displacedListener.set(existingIdAndListener.listener());
+ }
+
+ return new ConnectionIdAndListener(connectionId, listener);
+ });
+
+ if (displacedListener.get() != null) {
+ listenerEventExecutor.execute(() -> displacedListener.get().handleConnectionDisplaced(true));
+ }
+
+ return subscribeFuture.get()
+ .thenCompose(ignored -> clusterClient.withBinaryCluster(connection -> connection.async()
+ .spublish(clientPresenceKey, buildClientConnectedMessage(connectionId))))
+ .handle((ignored, throwable) -> {
+ if (throwable != null) {
+ PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER.increment();
+ }
+
+ return connectionId;
+ });
+ }
+
+ /**
+ * Removes the "presence" for the given device. The presence is removed if and only if the given connection ID matches
+ * the connection ID for the currently-registered presence. Callers should call this method when they have closed or
+ * intend to close the client's underlying network connection.
+ *
+ * @param accountIdentifier the identifier of the account for the disconnected device
+ * @param deviceId the ID of the disconnected device within the given account
+ * @param connectionId the ID of the connection that has been closed (or will be closed)
+ *
+ * @return a future that completes when the presence has been removed
+ */
+ public CompletionStage handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final UUID connectionId) {
+ if (pubSubConnection == null) {
+ throw new IllegalStateException("Presence manager not started");
+ }
+
+ if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
+ return CompletableFuture.completedFuture(null);
+ }
+
+ final AtomicReference> unsubscribeFuture = new AtomicReference<>();
+
+ // Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
+ // particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
+ //
+ // > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
+ // > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
+ // > computation is in progress, so the computation should be short and simple.
+ //
+ // This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
+ // as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
+ // operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
+ listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
+ (ignored, existingIdAndListener) -> {
+ final ConnectionIdAndListener remainingIdAndListener;
+
+ if (existingIdAndListener == null) {
+ remainingIdAndListener = null;
+ } else if (existingIdAndListener.connectionIdentifier().equals(connectionId)) {
+ remainingIdAndListener = null;
+ } else {
+ remainingIdAndListener = existingIdAndListener;
+ }
+
+ if (remainingIdAndListener == null) {
+ // Only unsubscribe if there's no listener remaining
+ unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
+ connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId)))
+ .thenRun(Util.NOOP));
+ } else {
+ unsubscribeFuture.set(CompletableFuture.completedFuture(null));
+ }
+
+ return remainingIdAndListener;
+ });
+
+ return unsubscribeFuture.get()
+ .whenComplete((ignored, throwable) -> {
+ if (throwable != null) {
+ UNSUBSCRIBE_ERROR_COUNTER.increment();
+ }
+ });
+ }
+
+ /**
+ * Publishes an event notifying a specific device that a new message is available for retrieval. This method indicates
+ * whether the target device is "present" (i.e. has an active listener). Callers may choose to take follow-up action
+ * (like sending a push notification) if the target device is not present.
+ *
+ * @param accountIdentifier the account identifier of the receiving device
+ * @param deviceId the ID of the receiving device within the target account
+ *
+ * @return a future that yields {@code true} if the target device had an active listener or {@code false} otherwise
+ */
+ public CompletionStage handleNewMessageAvailable(final UUID accountIdentifier, final byte deviceId) {
+ if (pubSubConnection == null) {
+ throw new IllegalStateException("Presence manager not started");
+ }
+
+ if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
+ return CompletableFuture.completedFuture(false);
+ }
+
+ return pubSubConnection.withPubSubConnection(connection ->
+ connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
+ .thenApply(listeners -> listeners > 0);
+ }
+
+ /**
+ * Tests whether a client with the given account/device is connected to this presence manager instance.
+ *
+ * @param accountUuid the account identifier for the client to check
+ * @param deviceId the ID of the device within the given account
+ *
+ * @return {@code true} if a client with the given account/device is connected to this presence manager instance or
+ * {@code false} if the client is not connected at all or is connected to a different presence manager instance
+ */
+ public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) {
+ return listenersByAccountAndDeviceIdentifier.containsKey(new AccountAndDeviceIdentifier(accountUuid, deviceId));
+ }
+
+ /**
+ * Broadcasts a request that all devices associated with the identified account and connected to any client presence
+ * instance close their network connections.
+ *
+ * @param accountIdentifier the account identifier for which to request disconnection
+ *
+ * @return a future that completes when the request has been sent
+ */
+ public CompletableFuture requestDisconnection(final UUID accountIdentifier) {
+ return requestDisconnection(accountIdentifier, Device.ALL_POSSIBLE_DEVICE_IDS);
+ }
+
+ /**
+ * Broadcasts a request that the specified devices associated with the identified account and connected to any client
+ * presence instance close their network connections.
+ *
+ * @param accountIdentifier the account identifier for which to request disconnection
+ * @param deviceIds the IDs of the devices for which to request disconnection
+ *
+ * @return a future that completes when the request has been sent
+ */
+ public CompletableFuture requestDisconnection(final UUID accountIdentifier, final Collection deviceIds) {
+ return CompletableFuture.allOf(deviceIds.stream()
+ .map(deviceId -> {
+ final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
+
+ return clusterClient.withBinaryCluster(connection -> connection.async()
+ .spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES))
+ .toCompletableFuture();
+ })
+ .toArray(CompletableFuture[]::new));
+ }
+
+ @VisibleForTesting
+ void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
+ final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent);
+
+ final Map> clientPresenceKeysBySlot = new HashMap<>();
+
+ // Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands
+ listenersByAccountAndDeviceIdentifier.keySet()
+ .stream()
+ .map(accountAndDeviceIdentifier -> getClientPresenceKey(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId()))
+ .forEach(clientPresenceKey -> {
+ final int slot = SlotHash.getSlot(clientPresenceKey);
+
+ if (changedSlots[slot]) {
+ clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientPresenceKey);
+ }
+ });
+
+ // Issue one resubscription command per affected slot
+ clientPresenceKeysBySlot.forEach((slot, clientPresenceKeys) -> {
+ if (pubSubConnection != null) {
+ final byte[][] clientPresenceKeyArray = clientPresenceKeys.toArray(byte[][]::new);
+ pubSubConnection.usePubSubConnection(connection -> connection.sync().ssubscribe(clientPresenceKeyArray));
+ }
+ });
+ }
+
+ @Override
+ public void smessage(final RedisClusterNode node, final byte[] shardChannel, final byte[] message) {
+ final ClientEvent clientEvent;
+
+ try {
+ clientEvent = ClientEvent.parseFrom(message);
+ } catch (final InvalidProtocolBufferException e) {
+ logger.error("Failed to parse pub/sub event protobuf", e);
+ return;
+ }
+
+ final AccountAndDeviceIdentifier accountAndDeviceIdentifier = parseClientPresenceKey(shardChannel);
+
+ @Nullable final ConnectionIdAndListener connectionIdAndListener =
+ listenersByAccountAndDeviceIdentifier.get(accountAndDeviceIdentifier);
+
+ if (connectionIdAndListener != null) {
+ switch (clientEvent.getEventCase()) {
+ case NEW_MESSAGE_AVAILABLE -> connectionIdAndListener.listener().handleNewMessageAvailable();
+
+ case CLIENT_CONNECTED -> {
+ final UUID connectionId = UUIDUtil.fromByteString(clientEvent.getClientConnected().getConnectionId());
+
+ if (!connectionIdAndListener.connectionIdentifier().equals(connectionId)) {
+ listenerEventExecutor.execute(() ->
+ connectionIdAndListener.listener().handleConnectionDisplaced(true));
+ }
+ }
+
+ case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() ->
+ connectionIdAndListener.listener().handleConnectionDisplaced(false));
+
+ default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass());
+ }
+ } else {
+ MESSAGE_WITHOUT_LISTENER_COUNTER.increment();
+ }
+ }
+
+ private static byte[] buildClientConnectedMessage(final UUID connectionId) {
+ return ClientEvent.newBuilder()
+ .setClientConnected(ClientConnectedEvent.newBuilder()
+ .setConnectionId(UUIDUtil.toByteString(connectionId))
+ .build())
+ .build()
+ .toByteArray();
+ }
+
+ @VisibleForTesting
+ static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) {
+ return ("client_presence::{" + accountIdentifier + ":" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
+ }
+
+ private static AccountAndDeviceIdentifier parseClientPresenceKey(final byte[] clientPresenceKeyBytes) {
+ final String clientPresenceKey = new String(clientPresenceKeyBytes, StandardCharsets.UTF_8);
+ final int uuidStart = "client_presence::{".length();
+
+ final UUID accountIdentifier = UUID.fromString(clientPresenceKey.substring(uuidStart, uuidStart + 36));
+ final byte deviceId = Byte.parseByte(clientPresenceKey.substring(uuidStart + 37, clientPresenceKey.length() - 1));
+
+ return new AccountAndDeviceIdentifier(accountIdentifier, deviceId);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java
index fe795d11a..e86eb99f6 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java
@@ -11,6 +11,7 @@ import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.scheduler.Scheduler;
+import java.util.function.Consumer;
public class FaultTolerantPubSubClusterConnection extends AbstractFaultTolerantPubSubConnection> {
@@ -32,7 +33,7 @@ public class FaultTolerantPubSubClusterConnection extends AbstractFaultTol
this.topologyChangedEventScheduler = topologyChangedEventScheduler;
}
- public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) {
+ public void subscribeToClusterTopologyChangedEvents(final Consumer eventHandler) {
usePubSubConnection(connection -> connection.getResources().eventBus().get()
.filter(event -> {
@@ -53,7 +54,7 @@ public class FaultTolerantPubSubClusterConnection extends AbstractFaultTol
resubscribeRetry.executeRunnable(() -> {
try {
- eventHandler.run();
+ eventHandler.accept((ClusterTopologyChangedEvent) event);
} catch (final RuntimeException e) {
logger.warn("Resubscribe for {} failed", getName(), e);
throw e;
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java
index 7f5248336..5053c7c6b 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java
@@ -202,4 +202,11 @@ public class FaultTolerantRedisClusterClient {
Schedulers.newSingle(name + "-redisPubSubEvents", true));
}
+ public FaultTolerantPubSubClusterConnection createBinaryPubSubConnection() {
+ final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(ByteArrayCodec.INSTANCE);
+ pubSubConnections.add(pubSubConnection);
+
+ return new FaultTolerantPubSubClusterConnection<>(name, pubSubConnection, topologyChangedEventRetry,
+ Schedulers.newSingle(name + "-redisPubSubEvents", true));
+ }
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
index 4f402c531..989dcd04b 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
@@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
@@ -126,6 +127,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen
private final SecureStorageClient secureStorageClient;
private final SecureValueRecovery2Client secureValueRecovery2Client;
private final ClientPresenceManager clientPresenceManager;
+ private final PubSubClientEventManager pubSubClientEventManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final ClientPublicKeysManager clientPublicKeysManager;
private final Executor accountLockExecutor;
@@ -205,6 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen
final SecureStorageClient secureStorageClient,
final SecureValueRecovery2Client secureValueRecovery2Client,
final ClientPresenceManager clientPresenceManager,
+ final PubSubClientEventManager pubSubClientEventManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final ClientPublicKeysManager clientPublicKeysManager,
final Executor accountLockExecutor,
@@ -223,6 +226,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen
this.secureStorageClient = secureStorageClient;
this.secureValueRecovery2Client = secureValueRecovery2Client;
this.clientPresenceManager = clientPresenceManager;
+ this.pubSubClientEventManager = pubSubClientEventManager;
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.clientPublicKeysManager = clientPublicKeysManager;
this.accountLockExecutor = accountLockExecutor;
@@ -329,7 +333,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci))
- .thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor)
+ .thenRunAsync(() -> {
+ clientPresenceManager.disconnectAllPresencesForUuid(aci);
+ pubSubClientEventManager.requestDisconnection(aci);
+ }, clientPresenceExecutor)
.thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems))
.thenCompose(ignored -> {
// We should have cleared all messages before overwriting the old account, but more may have arrived
@@ -594,6 +601,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen
.whenCompleteAsync((ignored, throwable) -> {
if (throwable == null) {
RedisOperation.unchecked(() -> clientPresenceManager.disconnectPresence(accountIdentifier, deviceId));
+ pubSubClientEventManager.requestDisconnection(accountIdentifier, List.of(deviceId));
}
}, clientPresenceExecutor);
}
@@ -1240,9 +1248,11 @@ public class AccountsManager extends RedisPubSubAdapter implemen
registrationRecoveryPasswordsManager.removeForNumber(account.getNumber()))
.thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems))
.thenCompose(ignored -> redisDeleteAsync(account))
- .thenRunAsync(() -> RedisOperation.unchecked(() ->
- account.getDevices().forEach(device ->
- clientPresenceManager.disconnectPresence(account.getUuid(), device.getId()))), clientPresenceExecutor);
+ .thenRunAsync(() -> {
+ RedisOperation.unchecked(() -> clientPresenceManager.disconnectAllPresencesForUuid(account.getUuid()));
+
+ pubSubClientEventManager.requestDisconnection(account.getUuid());
+ }, clientPresenceExecutor);
}
private String getAccountMapKey(String key) {
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java
index 720cc3fb1..dbb2fd7bd 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java
@@ -13,6 +13,7 @@ import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
@@ -247,7 +248,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp
pubSubConnection.usePubSubConnection(connection -> connection.sync().upstream().commands().unsubscribe());
}
- private void resubscribeAll() {
+ private void resubscribeAll(final ClusterTopologyChangedEvent event) {
final Set queueNames;
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java
index 188e1e4f4..f5fcd9a0f 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java
@@ -6,6 +6,12 @@
package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
public class RedisClusterUtil {
@@ -38,4 +44,51 @@ public class RedisClusterUtil {
public static String getMinimalHashTag(final int slot) {
return HASHES_BY_SLOT[slot];
}
+
+ /**
+ * Returns an array indicating which slots have moved as part of a {@link ClusterTopologyChangedEvent}. The elements
+ * of the array map to slots in the cluster; for example, if slot 1234 has changed, then element 1234 of the returned
+ * array will be {@code true}.
+ *
+ * @param clusterTopologyChangedEvent the event from which to derive an array of changed slots
+ *
+ * @return an array indicating which slots of changed
+ */
+ public static boolean[] getChangedSlots(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
+ final Map beforeNodesById = clusterTopologyChangedEvent.before().stream()
+ .collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node));
+
+ final Map afterNodesById = clusterTopologyChangedEvent.after().stream()
+ .collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node));
+
+ final Set nodeIds = new HashSet<>(beforeNodesById.keySet());
+ nodeIds.addAll(afterNodesById.keySet());
+
+ final boolean[] changedSlots = new boolean[SlotHash.SLOT_COUNT];
+
+ for (final String nodeId : nodeIds) {
+ if (beforeNodesById.containsKey(nodeId) && afterNodesById.containsKey(nodeId)) {
+ // This node was present before and after the topology change, but its slots may have changed
+ final boolean[] beforeSlots = new boolean[SlotHash.SLOT_COUNT];
+ beforeNodesById.get(nodeId).getSlots().forEach(slot -> beforeSlots[slot] = true);
+
+ final boolean[] afterSlots = new boolean[SlotHash.SLOT_COUNT];
+ afterNodesById.get(nodeId).getSlots().forEach(slot -> afterSlots[slot] = true);
+
+ for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
+ changedSlots[slot] |= beforeSlots[slot] ^ afterSlots[slot];
+ }
+ } else if (beforeNodesById.containsKey(nodeId)) {
+ // The node was present before the topology change, but is gone now; all of its slots should be considered
+ // changed
+ beforeNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true);
+ } else {
+ // The node was present after the change, but wasn't there before; all of its slots should be considered
+ // changed
+ afterNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true);
+ }
+ }
+
+ return changedSlots;
+ }
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
index 68375d341..6d7519f2c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
@@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
+import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
@@ -19,6 +20,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -47,6 +49,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final ClientPresenceManager clientPresenceManager;
+ private final PubSubClientEventManager pubSubClientEventManager;
private final ScheduledExecutorService scheduledExecutorService;
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
@@ -55,12 +58,15 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
+ private transient UUID connectionId;
+
public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager,
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
ClientPresenceManager clientPresenceManager,
+ PubSubClientEventManager pubSubClientEventManager,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager,
@@ -71,6 +77,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.clientPresenceManager = clientPresenceManager;
+ this.pubSubClientEventManager = pubSubClientEventManager;
this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
@@ -121,6 +128,12 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// It's preferable to start sending push notifications as soon as possible.
RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection));
+ if (connectionId != null) {
+ pubSubClientEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
+ auth.getAuthenticatedDevice().getId(),
+ connectionId);
+ }
+
// Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection
// will not be notified and will not change its state, but that's okay because it has already closed and
// attempts to deliver mesages via this connection will not succeed.
@@ -147,6 +160,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
+ pubSubClientEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), null)
+ .thenAccept(connectionId -> this.connectionId = connectionId);
renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() ->
clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())),
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
index e21ace10d..b87281b94 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
+import org.whispersystems.textsecuregcm.push.ClientEventListener;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@@ -63,15 +64,13 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
-public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener {
+public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener, ClientEventListener {
private static final DistributionSummary messageTime = Metrics.summary(
name(MessageController.class, "messageDeliveryDuration"));
private static final DistributionSummary primaryDeviceMessageTime = Metrics.summary(
name(MessageController.class, "primaryDeviceMessageDeliveryDuration"));
private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage"));
- private static final Counter messageAvailableCounter = Metrics.counter(
- name(WebSocketConnection.class, "messagesAvailable"));
private static final Counter messagesPersistedCounter = Metrics.counter(
name(WebSocketConnection.class, "messagesPersisted"));
private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent"));
@@ -91,6 +90,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
"sendMessages");
private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class,
"sendMessageError");
+ private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable");
+
+ private static final String PRESENCE_MANAGER_TAG = "presenceManager";
private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message";
private static final String ERROR_TYPE_TAG = "errorType";
@@ -468,7 +470,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
return false;
}
- messageAvailableCounter.increment();
+ Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME,
+ PRESENCE_MANAGER_TAG, "legacy")
+ .increment();
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
@@ -477,6 +481,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
return true;
}
+ @Override
+ public void handleNewMessageAvailable() {
+ Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME,
+ PRESENCE_MANAGER_TAG, "pubsub")
+ .increment();
+ }
+
@Override
public boolean handleMessagesPersisted() {
if (!client.isOpen()) {
@@ -497,7 +508,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public void handleDisplacement(final boolean connectedElsewhere) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
- Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere))
+ Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
+ Tag.of(PRESENCE_MANAGER_TAG, "legacy")
);
Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
@@ -522,6 +534,17 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
}
}
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ final Tags tags = Tags.of(
+ UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
+ Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
+ Tag.of(PRESENCE_MANAGER_TAG, "pubsub")
+ );
+
+ Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
+ }
+
private record StoredMessageInfo(UUID guid, long serverTimestamp) {
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
index 5db958f84..c86d68d17 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
@@ -31,10 +31,12 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
+import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.push.APNSender;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
@@ -141,6 +143,8 @@ record CommandDependencies(
.maxThreads(1).minThreads(1).build();
ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d"))
.maxThreads(16).minThreads(16).build();
+ ExecutorService clientEventExecutor = environment.lifecycle()
+ .virtualExecutorService(name(name, "clientEvent-%d"));
ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build();
@@ -214,6 +218,9 @@ record CommandDependencies(
storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor);
+ ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
+ dynamicConfigurationManager);
+ PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
@@ -230,7 +237,7 @@ record CommandDependencies(
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
- secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
+ secureStorageClient, secureValueRecovery2Client, clientPresenceManager, pubSubClientEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
@@ -269,6 +276,7 @@ record CommandDependencies(
environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager);
+ environment.lifecycle().manage(pubSubClientEventManager);
environment.lifecycle().manage(new ManagedAwsCrt());
return new CommandDependencies(
diff --git a/service/src/main/proto/ClientPresence.proto b/service/src/main/proto/ClientPresence.proto
new file mode 100644
index 000000000..b2141009d
--- /dev/null
+++ b/service/src/main/proto/ClientPresence.proto
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+syntax = "proto3";
+
+package org.signal.chat.presence;
+
+option java_package = "org.whispersystems.textsecuregcm.push";
+option java_multiple_files = true;
+
+message ClientEvent {
+ oneof event {
+ NewMessageAvailableEvent new_message_available = 1;
+ ClientConnectedEvent client_connected = 2;
+ DisconnectRequested disconnect_requested = 3;
+ }
+}
+
+/**
+ * Indicates that a new message is available for the client to retrieve.
+ */
+message NewMessageAvailableEvent {
+}
+
+/**
+ * Indicates that a client has connected to the presence system.
+ */
+message ClientConnectedEvent {
+ bytes connection_id = 1;
+}
+
+/**
+ * Indicates that the server has requested that the client disconnect due to
+ * (for example) account lifecycle events.
+ */
+message DisconnectRequested {
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java
index 166e86438..5d5eccaab 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java
@@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth;
-import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
@@ -61,6 +60,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -97,6 +97,7 @@ class LinkedDeviceRefreshRequirementProviderTest {
private AccountsManager accountsManager;
private ClientPresenceManager clientPresenceManager;
+ private PubSubClientEventManager pubSubClientEventManager;
private LinkedDeviceRefreshRequirementProvider provider;
@@ -104,11 +105,12 @@ class LinkedDeviceRefreshRequirementProviderTest {
void setup() {
accountsManager = mock(AccountsManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
+ pubSubClientEventManager = mock(PubSubClientEventManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener =
- new WebsocketRefreshRequestEventListener(clientPresenceManager, provider);
+ new WebsocketRefreshRequestEventListener(clientPresenceManager, pubSubClientEventManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
@@ -146,6 +148,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3);
+
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
}
@ParameterizedTest
@@ -173,8 +179,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
- initialDeviceIds.forEach(deviceId ->
- verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
+ initialDeviceIds.forEach(deviceId -> {
+ verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId);
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(deviceId));
+ });
verifyNoMoreInteractions(clientPresenceManager);
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java
index 9911bd48c..0bb77228e 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java
@@ -28,6 +28,7 @@ import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.EnumSet;
+import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.servlet.DispatcherType;
@@ -47,6 +48,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -74,6 +76,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
+ private static final PubSubClientEventManager PUBSUB_CLIENT_PRESENCE = mock(PubSubClientEventManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
@@ -122,9 +125,9 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
- .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
+ .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
environment.jersey()
- .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
+ .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
@@ -215,6 +218,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
+
+ verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
+ .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
+ verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
}
@Test
@@ -231,6 +238,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
+
+ verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
+ .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
+ verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
}
@ParameterizedTest
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java
index a6d0c0949..d922e526b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java
@@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -47,6 +48,7 @@ class RegistrationLockVerificationManagerTest {
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
+ private final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock(
ExternalServiceCredentialsGenerator.class);
private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock(
@@ -56,7 +58,7 @@ class RegistrationLockVerificationManagerTest {
private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
- accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
+ accountsManager, clientPresenceManager, pubSubClientEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
private final RateLimiter pinLimiter = mock(RateLimiter.class);
@@ -108,6 +110,7 @@ class RegistrationLockVerificationManagerTest {
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
}
verify(clientPresenceManager).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(Device.PRIMARY_ID));
try {
verify(pushNotificationManager).sendAttemptLoginNotification(any(), eq("failedRegistrationLock"));
} catch (NotPushRegisteredException npre) {}
@@ -131,6 +134,7 @@ class RegistrationLockVerificationManagerTest {
} catch (NotPushRegisteredException npre) {}
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
+ verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
});
}
};
@@ -169,6 +173,7 @@ class RegistrationLockVerificationManagerTest {
verify(account, never()).lockAuthTokenHash();
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
+ verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
}
static Stream testSuccess() {
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java
index d54a41651..9078b14df 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java
@@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@@ -111,6 +112,7 @@ class DeviceControllerTest {
private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class);
private static final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
+ private static final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private static final Map deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now();
@@ -131,7 +133,8 @@ class DeviceControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
- .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
+ .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
+ pubSubClientEventManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController)
.build();
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java
index 257ffd7bc..344505e47 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java
@@ -21,6 +21,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -39,6 +40,7 @@ class MessageSenderTest {
private MessageProtos.Envelope message;
private ClientPresenceManager clientPresenceManager;
+ private PubSubClientEventManager pubSubClientEventManager;
private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager;
private MessageSender messageSender;
@@ -54,9 +56,14 @@ class MessageSenderTest {
message = generateRandomMessage();
clientPresenceManager = mock(ClientPresenceManager.class);
+ pubSubClientEventManager = mock(PubSubClientEventManager.class);
messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class);
- messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager);
+
+ when(pubSubClientEventManager.handleNewMessageAvailable(any(), anyByte()))
+ .thenReturn(CompletableFuture.completedFuture(true));
+
+ messageSender = new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(device.getId()).thenReturn(DEVICE_ID);
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java
new file mode 100644
index 000000000..38a04aa83
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java
@@ -0,0 +1,337 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.push;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
+import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.IntStream;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
+import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
+import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
+
+@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
+class PubSubClientEventManagerTest {
+
+ private PubSubClientEventManager localPresenceManager;
+ private PubSubClientEventManager remotePresenceManager;
+
+ private static ExecutorService clientEventExecutor;
+
+ @RegisterExtension
+ static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
+
+ private static class ClientEventAdapter implements ClientEventListener {
+
+ @Override
+ public void handleNewMessageAvailable() {
+ }
+
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ }
+ }
+
+ @BeforeAll
+ static void setUpBeforeAll() {
+ clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
+ }
+
+ @BeforeEach
+ void setUp() {
+ final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
+ when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
+
+ localPresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
+ remotePresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
+
+ localPresenceManager.start();
+ remotePresenceManager.start();
+ }
+
+ @AfterEach
+ void tearDown() {
+ localPresenceManager.stop();
+ remotePresenceManager.stop();
+ }
+
+ @AfterAll
+ static void tearDownAfterAll() {
+ clientEventExecutor.shutdown();
+ }
+
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void handleClientConnected(final boolean displaceRemotely) throws InterruptedException {
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = Device.PRIMARY_ID;
+
+ final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
+ final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
+
+ final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
+
+ localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ synchronized (firstListenerDisplaced) {
+ firstListenerDisplaced.set(true);
+ firstListenerConnectedElsewhere.set(connectedElsewhere);
+
+ firstListenerDisplaced.notifyAll();
+ }
+ }
+ }).toCompletableFuture().join();
+
+ assertFalse(firstListenerDisplaced.get());
+ assertFalse(secondListenerDisplaced.get());
+
+ final PubSubClientEventManager displacingManager =
+ displaceRemotely ? remotePresenceManager : localPresenceManager;
+
+ displacingManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ secondListenerDisplaced.set(true);
+ }
+ }).toCompletableFuture().join();
+
+ synchronized (firstListenerDisplaced) {
+ while (!firstListenerDisplaced.get()) {
+ firstListenerDisplaced.wait();
+ }
+ }
+
+ assertTrue(firstListenerDisplaced.get());
+ assertFalse(secondListenerDisplaced.get());
+
+ assertTrue(firstListenerConnectedElsewhere.get());
+ }
+
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void handleNewMessageAvailable(final boolean messageAvailableRemotely) throws InterruptedException {
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = Device.PRIMARY_ID;
+
+ final AtomicBoolean messageReceived = new AtomicBoolean(false);
+
+ localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
+ @Override
+ public void handleNewMessageAvailable() {
+ synchronized (messageReceived) {
+ messageReceived.set(true);
+ messageReceived.notifyAll();
+ }
+ }
+ }).toCompletableFuture().join();
+
+ final PubSubClientEventManager messagePresenceManager =
+ messageAvailableRemotely ? remotePresenceManager : localPresenceManager;
+
+ assertTrue(messagePresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
+
+ synchronized (messageReceived) {
+ while (!messageReceived.get()) {
+ messageReceived.wait();
+ }
+ }
+
+ assertTrue(messageReceived.get());
+ }
+
+ @Test
+ void handleClientDisconnected() {
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = Device.PRIMARY_ID;
+
+ final UUID connectionId =
+ localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
+ .toCompletableFuture().join();
+
+ assertTrue(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
+
+ localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId).toCompletableFuture().join();
+
+ assertFalse(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
+ }
+
+ @Test
+ void isLocallyPresent() {
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = Device.PRIMARY_ID;
+
+ assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+ assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+
+ final UUID connectionId =
+ localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
+ .toCompletableFuture()
+ .join();
+
+ assertTrue(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+ assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+
+ localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId)
+ .toCompletableFuture()
+ .join();
+
+ assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+ assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
+ }
+
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException {
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte firstDeviceId = Device.PRIMARY_ID;
+ final byte secondDeviceId = firstDeviceId + 1;
+
+ final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
+ final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
+
+ final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
+
+ localPresenceManager.handleClientConnected(accountIdentifier, firstDeviceId, new ClientEventAdapter() {
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ synchronized (firstListenerDisplaced) {
+ firstListenerDisplaced.set(true);
+ firstListenerConnectedElsewhere.set(connectedElsewhere);
+
+ firstListenerDisplaced.notifyAll();
+ }
+ }
+ }).toCompletableFuture().join();
+
+ localPresenceManager.handleClientConnected(accountIdentifier, secondDeviceId, new ClientEventAdapter() {
+ @Override
+ public void handleConnectionDisplaced(final boolean connectedElsewhere) {
+ synchronized (secondListenerDisplaced) {
+ secondListenerDisplaced.set(true);
+ secondListenerDisplaced.notifyAll();
+ }
+ }
+ }).toCompletableFuture().join();
+
+ assertFalse(firstListenerDisplaced.get());
+ assertFalse(secondListenerDisplaced.get());
+
+ final PubSubClientEventManager displacingManager =
+ requestDisconnectionRemotely ? remotePresenceManager : localPresenceManager;
+
+ displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join();
+
+ synchronized (firstListenerDisplaced) {
+ while (!firstListenerDisplaced.get()) {
+ firstListenerDisplaced.wait();
+ }
+ }
+
+ assertTrue(firstListenerDisplaced.get());
+ assertFalse(secondListenerDisplaced.get());
+
+ assertFalse(firstListenerConnectedElsewhere.get());
+ }
+
+ @Test
+ void resubscribe() {
+ final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
+ when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
+
+ @SuppressWarnings("unchecked") final RedisClusterPubSubCommands pubSubCommands =
+ mock(RedisClusterPubSubCommands.class);
+
+ @SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands pubSubAsyncCommands =
+ mock(RedisClusterPubSubAsyncCommands.class);
+
+ when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null));
+
+ final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder()
+ .binaryPubSubCommands(pubSubCommands)
+ .binaryPubSubAsyncCommands(pubSubAsyncCommands)
+ .build();
+
+ final PubSubClientEventManager presenceManager =
+ new PubSubClientEventManager(clusterClient, Runnable::run, experimentEnrollmentManager);
+
+ presenceManager.start();
+
+ final UUID firstAccountIdentifier = UUID.randomUUID();
+ final byte firstDeviceId = Device.PRIMARY_ID;
+ final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
+
+ final UUID secondAccountIdentifier;
+ final byte secondDeviceId = firstDeviceId + 1;
+
+ // Make sure that the two subscriptions wind up in different slots
+ {
+ UUID candidateIdentifier;
+
+ do {
+ candidateIdentifier = UUID.randomUUID();
+ } while (SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(candidateIdentifier, secondDeviceId)) == firstSlot);
+
+ secondAccountIdentifier = candidateIdentifier;
+ }
+
+ presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
+ presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
+
+ final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
+
+ final String firstNodeId = UUID.randomUUID().toString();
+
+ final RedisClusterNode firstBeforeNode = mock(RedisClusterNode.class);
+ when(firstBeforeNode.getNodeId()).thenReturn(firstNodeId);
+ when(firstBeforeNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT).boxed().toList());
+
+ final RedisClusterNode firstAfterNode = mock(RedisClusterNode.class);
+ when(firstAfterNode.getNodeId()).thenReturn(firstNodeId);
+ when(firstAfterNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT)
+ .filter(slot -> slot != secondSlot)
+ .boxed()
+ .toList());
+
+ final RedisClusterNode secondAfterNode = mock(RedisClusterNode.class);
+ when(secondAfterNode.getNodeId()).thenReturn(UUID.randomUUID().toString());
+ when(secondAfterNode.getSlots()).thenReturn(List.of(secondSlot));
+
+ presenceManager.resubscribe(new ClusterTopologyChangedEvent(
+ List.of(firstBeforeNode),
+ List.of(firstAfterNode, secondAfterNode)));
+
+ verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
+ verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java
index cd5e60779..72c54618b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java
@@ -31,6 +31,7 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
@@ -46,7 +47,7 @@ class FaultTolerantPubSubClusterConnectionTest {
private TestPublisher eventPublisher;
- private Runnable resubscribe;
+ private Consumer resubscribe;
private AtomicInteger resubscribeCounter;
private CountDownLatch resubscribeFailure;
@@ -93,7 +94,7 @@ class FaultTolerantPubSubClusterConnectionTest {
resubscribeCounter = new AtomicInteger();
- resubscribe = () -> {
+ resubscribe = event -> {
try {
resubscribeCounter.incrementAndGet();
pubSubConnection.sync().nodes((ignored) -> true);
@@ -137,7 +138,7 @@ class FaultTolerantPubSubClusterConnectionTest {
void testFilterClusterTopologyChangeEvents() throws InterruptedException {
final CountDownLatch topologyEventLatch = new CountDownLatch(1);
- faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(topologyEventLatch::countDown);
+ faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(event -> topologyEventLatch.countDown());
final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class);
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java
index cbf444092..eaee1f505 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java
@@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -152,6 +153,7 @@ public class AccountCreationDeletionIntegrationTest {
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
+ mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java
index a9fcfc94a..274903c7b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java
@@ -37,6 +37,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -67,6 +68,7 @@ class AccountsManagerChangeNumberIntegrationTest {
private KeysManager keysManager;
private ClientPresenceManager clientPresenceManager;
+ private PubSubClientEventManager pubSubClientEventManager;
private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor;
@@ -119,6 +121,7 @@ class AccountsManagerChangeNumberIntegrationTest {
when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
clientPresenceManager = mock(ClientPresenceManager.class);
+ pubSubClientEventManager = mock(PubSubClientEventManager.class);
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName());
@@ -147,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest {
secureStorageClient,
svr2Client,
clientPresenceManager,
+ pubSubClientEventManager,
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,
@@ -281,7 +285,8 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow());
- verify(clientPresenceManager).disconnectPresence(existingAccountUuid, Device.PRIMARY_ID);
+ verify(clientPresenceManager).disconnectAllPresencesForUuid(existingAccountUuid);
+ verify(pubSubClientEventManager).requestDisconnection(existingAccountUuid);
assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalNumber));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondNumber));
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java
index a8d15bce6..1ba0a8e8b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java
@@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@@ -134,6 +135,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
+ mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
mock(Executor.class),
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java
index 3dfe355cf..228ebef2e 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java
@@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -63,6 +64,7 @@ public class AccountsManagerDeviceTransferIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
+ mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
mock(ExecutorService.class),
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java
index d92233584..f435d35a0 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java
@@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -118,6 +119,7 @@ class AccountsManagerTest {
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private ClientPresenceManager clientPresenceManager;
+ private PubSubClientEventManager pubSubClientEventManager;
private ClientPublicKeysManager clientPublicKeysManager;
private Map phoneNumberIdentifiersByE164;
@@ -153,6 +155,7 @@ class AccountsManagerTest {
messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
+ pubSubClientEventManager = mock(PubSubClientEventManager.class);
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
dynamicConfiguration = mock(DynamicConfiguration.class);
@@ -259,6 +262,7 @@ class AccountsManagerTest {
storageClient,
svr2Client,
clientPresenceManager,
+ pubSubClientEventManager,
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
mock(Executor.class),
@@ -799,6 +803,7 @@ class AccountsManagerTest {
verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId());
verify(clientPublicKeysManager).buildTransactWriteItemForDeletion(account.getUuid(), linkedDevice.getId());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId());
+ verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(linkedDevice.getId()));
}
@Test
@@ -817,6 +822,7 @@ class AccountsManagerTest {
verify(messagesManager, never()).clear(any(), anyByte());
verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte());
verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte());
+ verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
}
@Test
@@ -886,6 +892,7 @@ class AccountsManagerTest {
verify(messagesManager, times(2)).clear(existingUuid);
verify(profilesManager, times(2)).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
+ verify(pubSubClientEventManager).requestDisconnection(existingUuid);
}
@Test
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java
index 918a9c275..ebf9f506d 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java
@@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -146,6 +147,7 @@ class AccountsManagerUsernameIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
+ mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
Executors.newSingleThreadExecutor(),
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java
index e05cfcaa0..27dbcd548 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java
@@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -152,6 +153,7 @@ public class AddRemoveDeviceIntegrationTest {
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
+ mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
index b6b7ac306..86343e299 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
@@ -14,8 +14,12 @@ import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
+import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
+import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.function.Consumer;
import java.util.function.Function;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
public class RedisClusterHelper {
@@ -30,7 +34,12 @@ public class RedisClusterHelper {
final RedisAdvancedClusterAsyncCommands stringAsyncCommands,
final RedisAdvancedClusterCommands binaryCommands,
final RedisAdvancedClusterAsyncCommands binaryAsyncCommands,
- final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) {
+ final RedisAdvancedClusterReactiveCommands binaryReactiveCommands,
+ final RedisClusterPubSubCommands stringPubSubCommands,
+ final RedisClusterPubSubAsyncCommands stringAsyncPubSubCommands,
+ final RedisClusterPubSubCommands binaryPubSubCommands,
+ final RedisClusterPubSubAsyncCommands binaryAsyncPubSubCommands) {
+
final FaultTolerantRedisClusterClient cluster = mock(FaultTolerantRedisClusterClient.class);
final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class);
final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class);
@@ -59,6 +68,45 @@ public class RedisClusterHelper {
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
+ final StatefulRedisClusterPubSubConnection stringPubSubConnection =
+ mock(StatefulRedisClusterPubSubConnection.class);
+
+ final StatefulRedisClusterPubSubConnection binaryPubSubConnection =
+ mock(StatefulRedisClusterPubSubConnection.class);
+
+ final FaultTolerantPubSubClusterConnection faultTolerantPubSubClusterConnection =
+ mock(FaultTolerantPubSubClusterConnection.class);
+
+ final FaultTolerantPubSubClusterConnection faultTolerantBinaryPubSubClusterConnection =
+ mock(FaultTolerantPubSubClusterConnection.class);
+
+ when(stringPubSubConnection.sync()).thenReturn(stringPubSubCommands);
+ when(stringPubSubConnection.async()).thenReturn(stringAsyncPubSubCommands);
+ when(binaryPubSubConnection.sync()).thenReturn(binaryPubSubCommands);
+ when(binaryPubSubConnection.async()).thenReturn(binaryAsyncPubSubCommands);
+
+ when(cluster.createPubSubConnection()).thenReturn(faultTolerantPubSubClusterConnection);
+ when(cluster.createBinaryPubSubConnection()).thenReturn(faultTolerantBinaryPubSubClusterConnection);
+
+ when(faultTolerantPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(stringPubSubConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(stringPubSubConnection);
+ return null;
+ }).when(faultTolerantPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
+
+ when(faultTolerantBinaryPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(
+ invocation -> {
+ return invocation.getArgument(0, Function.class).apply(binaryPubSubConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(binaryPubSubConnection);
+ return null;
+ }).when(faultTolerantBinaryPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
+
return cluster;
}
@@ -77,6 +125,18 @@ public class RedisClusterHelper {
private RedisAdvancedClusterReactiveCommands binaryReactiveCommands =
mock(RedisAdvancedClusterReactiveCommands.class);
+ private RedisClusterPubSubCommands stringPubSubCommands =
+ mock(RedisClusterPubSubCommands.class);
+
+ private RedisClusterPubSubCommands binaryPubSubCommands =
+ mock(RedisClusterPubSubCommands.class);
+
+ private RedisClusterPubSubAsyncCommands stringPubSubAsyncCommands =
+ mock(RedisClusterPubSubAsyncCommands.class);
+
+ private RedisClusterPubSubAsyncCommands binaryPubSubAsyncCommands =
+ mock(RedisClusterPubSubAsyncCommands.class);
+
private Builder() {
}
@@ -107,9 +167,33 @@ public class RedisClusterHelper {
return this;
}
+ public Builder stringPubSubCommands(final RedisClusterPubSubCommands stringPubSubCommands) {
+ this.stringPubSubCommands = stringPubSubCommands;
+ return this;
+ }
+
+ public Builder binaryPubSubCommands(final RedisClusterPubSubCommands binaryPubSubCommands) {
+ this.binaryPubSubCommands = binaryPubSubCommands;
+ return this;
+ }
+
+ public Builder stringPubSubAsyncCommands(
+ final RedisClusterPubSubAsyncCommands stringPubSubAsyncCommands) {
+ this.stringPubSubAsyncCommands = stringPubSubAsyncCommands;
+ return this;
+ }
+
+ public Builder binaryPubSubAsyncCommands(
+ final RedisClusterPubSubAsyncCommands binaryPubSubAsyncCommands) {
+ this.binaryPubSubAsyncCommands = binaryPubSubAsyncCommands;
+ return this;
+ }
+
public FaultTolerantRedisClusterClient build() {
- return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands, binaryAsyncCommands,
- binaryReactiveCommands);
+ return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands,
+ binaryAsyncCommands,
+ binaryReactiveCommands, stringPubSubCommands, stringPubSubAsyncCommands, binaryPubSubCommands,
+ binaryPubSubAsyncCommands);
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java
index 465d6aac1..260362705 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java
@@ -5,17 +5,199 @@
package org.whispersystems.textsecuregcm.util;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
class RedisClusterUtilTest {
- @Test
- void testGetMinimalHashTag() {
- for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
- assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot)));
- }
+ @Test
+ void testGetMinimalHashTag() {
+ for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
+ assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot)));
}
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void getChangedSlots(final ClusterTopologyChangedEvent event, final boolean[] expectedSlotsChanged) {
+ assertArrayEquals(expectedSlotsChanged, RedisClusterUtil.getChangedSlots(event));
+ }
+
+ private static List getChangedSlots() {
+ final List arguments = new ArrayList<>();
+
+ // Slot moved from one node to another
+ {
+ final String firstNodeId = UUID.randomUUID().toString();
+ final String secondNodeId = UUID.randomUUID().toString();
+
+ final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
+ when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
+
+ final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
+ when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
+
+ final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
+ when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
+
+ final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
+ when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16384));
+
+ final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
+ List.of(firstNodeBefore, secondNodeBefore),
+ List.of(firstNodeAfter, secondNodeAfter));
+
+ final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
+ slotsChanged[8191] = true;
+
+ arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
+ }
+
+ // New node added to cluster
+ {
+ final String firstNodeId = UUID.randomUUID().toString();
+ final String secondNodeId = UUID.randomUUID().toString();
+
+ final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
+ when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
+
+ final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
+ when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
+
+ final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
+ when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
+
+ final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
+ when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 12288));
+
+ final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
+ when(thirdNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
+ when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(12288, 16384));
+
+ final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
+ List.of(firstNodeBefore, secondNodeBefore),
+ List.of(firstNodeAfter, secondNodeAfter, thirdNodeAfter));
+
+ final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
+
+ for (int slot = 12288; slot < 16384; slot++) {
+ slotsChanged[slot] = true;
+ }
+
+ arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
+ }
+
+ // Node removed from cluster
+ {
+ final String firstNodeId = UUID.randomUUID().toString();
+ final String secondNodeId = UUID.randomUUID().toString();
+
+ final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
+ when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
+
+ final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
+ when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 12288));
+
+ final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
+ when(thirdNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
+ when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(12288, 16384));
+
+ final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
+ when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
+ when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
+
+ final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
+ when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 16384));
+
+ final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
+ List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
+ List.of(firstNodeAfter, secondNodeAfter));
+
+ final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
+
+ for (int slot = 12288; slot < 16384; slot++) {
+ slotsChanged[slot] = true;
+ }
+
+ arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
+ }
+
+ // Node added, node removed, and slot moved
+ // Node removed from cluster
+ {
+ final String secondNodeId = UUID.randomUUID().toString();
+ final String thirdNodeId = UUID.randomUUID().toString();
+
+ final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
+ when(firstNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
+ when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 1));
+
+ final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
+ when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(1, 8192));
+
+ final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
+ when(thirdNodeBefore.getNodeId()).thenReturn(thirdNodeId);
+ when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
+
+ final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
+ when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
+ when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
+
+ final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
+ when(thirdNodeAfter.getNodeId()).thenReturn(thirdNodeId);
+ when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16383));
+
+ final RedisClusterNode fourthNodeAfter = mock(RedisClusterNode.class);
+ when(fourthNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
+ when(fourthNodeAfter.getSlots()).thenReturn(getSlotRange(16383, 16384));
+
+ final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
+ List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
+ List.of(secondNodeAfter, thirdNodeAfter, fourthNodeAfter));
+
+ final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
+ slotsChanged[0] = true;
+ slotsChanged[8191] = true;
+ slotsChanged[16383] = true;
+
+ arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
+ }
+
+ return arguments;
+ }
+
+ private static List getSlotRange(final int startInclusive, final int endExclusive) {
+ final List slots = new ArrayList<>(endExclusive - startInclusive);
+
+ for (int i = startInclusive; i < endExclusive; i++) {
+ slots.add(i);
+ }
+
+ return slots;
+ }
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
index 08c83f60e..5ce2d20bd 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
@@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -124,8 +125,8 @@ class WebSocketConnectionTest {
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
- mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager,
- mock(MessageDeliveryLoopMonitor.class));
+ mock(ClientPresenceManager.class), mock(PubSubClientEventManager.class), retrySchedulingExecutor,
+ messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))