diff --git a/pom.xml b/pom.xml index bfba1137d..e16a6f7fa 100644 --- a/pom.xml +++ b/pom.xml @@ -56,7 +56,7 @@ 3.0.2 1.9.24 1.5.1 - 6.3.2.RELEASE + 6.4.1.RELEASE 8.13.40 7.3 2.23.1 diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 02b91fc52..b716bd967 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -196,6 +196,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -569,6 +570,8 @@ public class WhisperServerService extends Application(AuthenticatedDevice.class)); - environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); + environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager, + pubSubClientEventManager)); environment.jersey().register(new TimestampResponseFilter()); /// @@ -1006,10 +1012,11 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000)); - provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); + provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager, + pubSubClientEventManager)); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java index 01b3b1567..7b0605cda 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java @@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -55,6 +56,7 @@ public class RegistrationLockVerificationManager { private final AccountsManager accounts; private final ClientPresenceManager clientPresenceManager; + private final PubSubClientEventManager pubSubClientEventManager; private final ExternalServiceCredentialsGenerator svr2CredentialGenerator; private final ExternalServiceCredentialsGenerator svr3CredentialGenerator; private final RateLimiters rateLimiters; @@ -62,7 +64,9 @@ public class RegistrationLockVerificationManager { private final PushNotificationManager pushNotificationManager; public RegistrationLockVerificationManager( - final AccountsManager accounts, final ClientPresenceManager clientPresenceManager, + final AccountsManager accounts, + final ClientPresenceManager clientPresenceManager, + final PubSubClientEventManager pubSubClientEventManager, final ExternalServiceCredentialsGenerator svr2CredentialGenerator, final ExternalServiceCredentialsGenerator svr3CredentialGenerator, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, @@ -70,6 +74,7 @@ public class RegistrationLockVerificationManager { final RateLimiters rateLimiters) { this.accounts = accounts; this.clientPresenceManager = clientPresenceManager; + this.pubSubClientEventManager = pubSubClientEventManager; this.svr2CredentialGenerator = svr2CredentialGenerator; this.svr3CredentialGenerator = svr3CredentialGenerator; this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; @@ -161,6 +166,7 @@ public class RegistrationLockVerificationManager { final List deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds); + pubSubClientEventManager.requestDisconnection(updatedAccount.getUuid(), deviceIds); try { // Send a push notification that prompts the client to attempt login and fail due to locked credentials diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java index 2f54d78ba..a645a790c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java @@ -10,6 +10,7 @@ import org.glassfish.jersey.server.monitoring.ApplicationEventListener; import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.storage.AccountsManager; /** @@ -20,9 +21,11 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener; public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager, - final ClientPresenceManager clientPresenceManager) { + final ClientPresenceManager clientPresenceManager, + final PubSubClientEventManager pubSubClientEventManager) { this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, + pubSubClientEventManager, new LinkedDeviceRefreshRequirementProvider(accountsManager), new PhoneNumberChangeRefreshRequirementProvider(accountsManager)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java index 9fbb84fab..65f8e4dcb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java @@ -10,6 +10,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.util.Arrays; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import javax.ws.rs.container.ResourceInfo; import javax.ws.rs.core.Context; @@ -19,10 +20,12 @@ import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; public class WebsocketRefreshRequestEventListener implements RequestEventListener { private final ClientPresenceManager clientPresenceManager; + private final PubSubClientEventManager pubSubClientEventManager; private final WebsocketRefreshRequirementProvider[] providers; private static final Counter DISPLACED_ACCOUNTS = Metrics.counter( @@ -35,9 +38,11 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene public WebsocketRefreshRequestEventListener( final ClientPresenceManager clientPresenceManager, + final PubSubClientEventManager pubSubClientEventManager, final WebsocketRefreshRequirementProvider... providers) { this.clientPresenceManager = clientPresenceManager; + this.pubSubClientEventManager = pubSubClientEventManager; this.providers = providers; } @@ -60,6 +65,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene try { displacedDevices.incrementAndGet(); clientPresenceManager.disconnectPresence(pair.first(), pair.second()); + pubSubClientEventManager.requestDisconnection(pair.first(), List.of(pair.second())); } catch (final Exception e) { logger.error("Could not displace device presence", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java new file mode 100644 index 000000000..1697fb545 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientEventListener.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +/** + * A client event listener handles events related to a client's message-retrieval presence. Handler methods are run on + * dedicated threads and may safely perform blocking operations. + */ +public interface ClientEventListener { + + /** + * Indicates that a new message is available in the connected client's message queue. + */ + void handleNewMessageAvailable(); + + /** + * Indicates that the client's presence has been displaced and the listener should close the client's underlying + * network connection. + * + * @param connectedElsewhere if {@code true}, indicates that the client's presence has been displaced by another + * connection from the same client + */ + void handleConnectionDisplaced(boolean connectedElsewhere); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java index 80083a4dc..54cbe6be4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java @@ -14,6 +14,7 @@ import io.lettuce.core.RedisFuture; import io.lettuce.core.ScriptOutputType; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.micrometer.core.instrument.Counter; @@ -277,7 +278,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter Metrics.counter(SEND_COUNTER_NAME, + CHANNEL_TAG_NAME, channel, + EPHEMERAL_TAG_NAME, String.valueOf(online), + CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent), + PUB_SUB_CLIENT_ONLINE_TAG_NAME, String.valueOf(Objects.requireNonNullElse(present, false)), + URGENT_TAG_NAME, String.valueOf(message.getUrgent()), + STORY_TAG_NAME, String.valueOf(message.getStory()), + SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) + .increment()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java new file mode 100644 index 000000000..e3d84751d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManager.java @@ -0,0 +1,407 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.lifecycle.Managed; +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.RedisClusterUtil; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import org.whispersystems.textsecuregcm.util.Util; +import javax.annotation.Nullable; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; + +/** + * The pub/sub-based client presence manager uses the Redis 7 sharded pub/sub system to notify connected clients that + * new messages are available for retrieval and report to senders whether a client was present to receive a message when + * sent. This system makes a best effort to ensure that a given client has only a single open connection across the + * fleet of servers, but cannot guarantee at-most-one behavior. + */ +public class PubSubClientEventManager extends RedisClusterPubSubAdapter implements Managed { + + private final FaultTolerantRedisClusterClient clusterClient; + private final Executor listenerEventExecutor; + + private final ExperimentEnrollmentManager experimentEnrollmentManager; + static final String EXPERIMENT_NAME = "pubSubPresenceManager"; + + @Nullable + private FaultTolerantPubSubClusterConnection pubSubConnection; + + private final Map listenersByAccountAndDeviceIdentifier; + + private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder() + .setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance()) + .build() + .toByteArray(); + + private static final byte[] DISCONNECT_REQUESTED_EVENT_BYTES = ClientEvent.newBuilder() + .setDisconnectRequested(DisconnectRequested.getDefaultInstance()) + .build() + .toByteArray(); + + private static final Counter PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER = + Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "publishClientConnectionEventError")); + + private static final Counter UNSUBSCRIBE_ERROR_COUNTER = + Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "unsubscribeError")); + + private static final Counter MESSAGE_WITHOUT_LISTENER_COUNTER = + Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "messageWithoutListener")); + + private static final String LISTENER_GAUGE_NAME = + MetricsUtil.name(PubSubClientEventManager.class, "listeners"); + + private static final Logger logger = LoggerFactory.getLogger(PubSubClientEventManager.class); + + private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) { + } + + private record ConnectionIdAndListener(UUID connectionIdentifier, ClientEventListener listener) { + } + + public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient, + final Executor listenerEventExecutor, + final ExperimentEnrollmentManager experimentEnrollmentManager) { + + this.clusterClient = clusterClient; + this.listenerEventExecutor = listenerEventExecutor; + this.experimentEnrollmentManager = experimentEnrollmentManager; + + this.listenersByAccountAndDeviceIdentifier = + Metrics.gaugeMapSize(LISTENER_GAUGE_NAME, Tags.empty(), new ConcurrentHashMap<>()); + } + + @Override + public synchronized void start() { + this.pubSubConnection = clusterClient.createBinaryPubSubConnection(); + this.pubSubConnection.usePubSubConnection(connection -> connection.addListener(this)); + + pubSubConnection.subscribeToClusterTopologyChangedEvents(this::resubscribe); + } + + @Override + public synchronized void stop() { + if (pubSubConnection != null) { + pubSubConnection.usePubSubConnection(connection -> { + connection.removeListener(this); + connection.close(); + }); + } + + pubSubConnection = null; + } + + /** + * Marks the given device as "present" and registers a listener for new messages and conflicting connections. If the + * given device already has a presence registered with this presence manager instance, that presence is displaced + * immediately and the listener's {@link ClientEventListener#handleConnectionDisplaced(boolean)} method is called. + * + * @param accountIdentifier the account identifier for the newly-connected device + * @param deviceId the ID of the newly-connected device within the given account + * @param listener the listener to notify when new messages or conflicting connections arrive for the newly-conencted + * device + * + * @return a future that yields a connection identifier when the new device's presence has been registered; the future + * may fail if a pub/sub subscription could not be established, in which case callers should close the client's + * connection to the server + */ + public CompletionStage handleClientConnected(final UUID accountIdentifier, final byte deviceId, final ClientEventListener listener) { + if (pubSubConnection == null) { + throw new IllegalStateException("Presence manager not started"); + } + + if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) { + return CompletableFuture.completedFuture(UUID.randomUUID()); + } + + final UUID connectionId = UUID.randomUUID(); + final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId); + final AtomicReference displacedListener = new AtomicReference<>(); + final AtomicReference> subscribeFuture = new AtomicReference<>(); + + // Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In + // particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says: + // + // > The entire method invocation is performed atomically. The supplied function is invoked exactly once per + // > invocation of this method. Some attempted update operations on this map by other threads may be blocked while + // > computation is in progress, so the computation should be short and simple. + // + // This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order + // as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued + // operation is asynchronous; we're not blocking on it in the scope of the `compute` operation. + listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId), + (key, existingIdAndListener) -> { + subscribeFuture.set(pubSubConnection.withPubSubConnection(connection -> + connection.async().ssubscribe(clientPresenceKey))); + + if (existingIdAndListener != null) { + displacedListener.set(existingIdAndListener.listener()); + } + + return new ConnectionIdAndListener(connectionId, listener); + }); + + if (displacedListener.get() != null) { + listenerEventExecutor.execute(() -> displacedListener.get().handleConnectionDisplaced(true)); + } + + return subscribeFuture.get() + .thenCompose(ignored -> clusterClient.withBinaryCluster(connection -> connection.async() + .spublish(clientPresenceKey, buildClientConnectedMessage(connectionId)))) + .handle((ignored, throwable) -> { + if (throwable != null) { + PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER.increment(); + } + + return connectionId; + }); + } + + /** + * Removes the "presence" for the given device. The presence is removed if and only if the given connection ID matches + * the connection ID for the currently-registered presence. Callers should call this method when they have closed or + * intend to close the client's underlying network connection. + * + * @param accountIdentifier the identifier of the account for the disconnected device + * @param deviceId the ID of the disconnected device within the given account + * @param connectionId the ID of the connection that has been closed (or will be closed) + * + * @return a future that completes when the presence has been removed + */ + public CompletionStage handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final UUID connectionId) { + if (pubSubConnection == null) { + throw new IllegalStateException("Presence manager not started"); + } + + if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) { + return CompletableFuture.completedFuture(null); + } + + final AtomicReference> unsubscribeFuture = new AtomicReference<>(); + + // Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In + // particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says: + // + // > The entire method invocation is performed atomically. The supplied function is invoked exactly once per + // > invocation of this method. Some attempted update operations on this map by other threads may be blocked while + // > computation is in progress, so the computation should be short and simple. + // + // This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order + // as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued + // operation is asynchronous; we're not blocking on it in the scope of the `compute` operation. + listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId), + (ignored, existingIdAndListener) -> { + final ConnectionIdAndListener remainingIdAndListener; + + if (existingIdAndListener == null) { + remainingIdAndListener = null; + } else if (existingIdAndListener.connectionIdentifier().equals(connectionId)) { + remainingIdAndListener = null; + } else { + remainingIdAndListener = existingIdAndListener; + } + + if (remainingIdAndListener == null) { + // Only unsubscribe if there's no listener remaining + unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection -> + connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId))) + .thenRun(Util.NOOP)); + } else { + unsubscribeFuture.set(CompletableFuture.completedFuture(null)); + } + + return remainingIdAndListener; + }); + + return unsubscribeFuture.get() + .whenComplete((ignored, throwable) -> { + if (throwable != null) { + UNSUBSCRIBE_ERROR_COUNTER.increment(); + } + }); + } + + /** + * Publishes an event notifying a specific device that a new message is available for retrieval. This method indicates + * whether the target device is "present" (i.e. has an active listener). Callers may choose to take follow-up action + * (like sending a push notification) if the target device is not present. + * + * @param accountIdentifier the account identifier of the receiving device + * @param deviceId the ID of the receiving device within the target account + * + * @return a future that yields {@code true} if the target device had an active listener or {@code false} otherwise + */ + public CompletionStage handleNewMessageAvailable(final UUID accountIdentifier, final byte deviceId) { + if (pubSubConnection == null) { + throw new IllegalStateException("Presence manager not started"); + } + + if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) { + return CompletableFuture.completedFuture(false); + } + + return pubSubConnection.withPubSubConnection(connection -> + connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES)) + .thenApply(listeners -> listeners > 0); + } + + /** + * Tests whether a client with the given account/device is connected to this presence manager instance. + * + * @param accountUuid the account identifier for the client to check + * @param deviceId the ID of the device within the given account + * + * @return {@code true} if a client with the given account/device is connected to this presence manager instance or + * {@code false} if the client is not connected at all or is connected to a different presence manager instance + */ + public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) { + return listenersByAccountAndDeviceIdentifier.containsKey(new AccountAndDeviceIdentifier(accountUuid, deviceId)); + } + + /** + * Broadcasts a request that all devices associated with the identified account and connected to any client presence + * instance close their network connections. + * + * @param accountIdentifier the account identifier for which to request disconnection + * + * @return a future that completes when the request has been sent + */ + public CompletableFuture requestDisconnection(final UUID accountIdentifier) { + return requestDisconnection(accountIdentifier, Device.ALL_POSSIBLE_DEVICE_IDS); + } + + /** + * Broadcasts a request that the specified devices associated with the identified account and connected to any client + * presence instance close their network connections. + * + * @param accountIdentifier the account identifier for which to request disconnection + * @param deviceIds the IDs of the devices for which to request disconnection + * + * @return a future that completes when the request has been sent + */ + public CompletableFuture requestDisconnection(final UUID accountIdentifier, final Collection deviceIds) { + return CompletableFuture.allOf(deviceIds.stream() + .map(deviceId -> { + final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId); + + return clusterClient.withBinaryCluster(connection -> connection.async() + .spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES)) + .toCompletableFuture(); + }) + .toArray(CompletableFuture[]::new)); + } + + @VisibleForTesting + void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) { + final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent); + + final Map> clientPresenceKeysBySlot = new HashMap<>(); + + // Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands + listenersByAccountAndDeviceIdentifier.keySet() + .stream() + .map(accountAndDeviceIdentifier -> getClientPresenceKey(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId())) + .forEach(clientPresenceKey -> { + final int slot = SlotHash.getSlot(clientPresenceKey); + + if (changedSlots[slot]) { + clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientPresenceKey); + } + }); + + // Issue one resubscription command per affected slot + clientPresenceKeysBySlot.forEach((slot, clientPresenceKeys) -> { + if (pubSubConnection != null) { + final byte[][] clientPresenceKeyArray = clientPresenceKeys.toArray(byte[][]::new); + pubSubConnection.usePubSubConnection(connection -> connection.sync().ssubscribe(clientPresenceKeyArray)); + } + }); + } + + @Override + public void smessage(final RedisClusterNode node, final byte[] shardChannel, final byte[] message) { + final ClientEvent clientEvent; + + try { + clientEvent = ClientEvent.parseFrom(message); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse pub/sub event protobuf", e); + return; + } + + final AccountAndDeviceIdentifier accountAndDeviceIdentifier = parseClientPresenceKey(shardChannel); + + @Nullable final ConnectionIdAndListener connectionIdAndListener = + listenersByAccountAndDeviceIdentifier.get(accountAndDeviceIdentifier); + + if (connectionIdAndListener != null) { + switch (clientEvent.getEventCase()) { + case NEW_MESSAGE_AVAILABLE -> connectionIdAndListener.listener().handleNewMessageAvailable(); + + case CLIENT_CONNECTED -> { + final UUID connectionId = UUIDUtil.fromByteString(clientEvent.getClientConnected().getConnectionId()); + + if (!connectionIdAndListener.connectionIdentifier().equals(connectionId)) { + listenerEventExecutor.execute(() -> + connectionIdAndListener.listener().handleConnectionDisplaced(true)); + } + } + + case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() -> + connectionIdAndListener.listener().handleConnectionDisplaced(false)); + + default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass()); + } + } else { + MESSAGE_WITHOUT_LISTENER_COUNTER.increment(); + } + } + + private static byte[] buildClientConnectedMessage(final UUID connectionId) { + return ClientEvent.newBuilder() + .setClientConnected(ClientConnectedEvent.newBuilder() + .setConnectionId(UUIDUtil.toByteString(connectionId)) + .build()) + .build() + .toByteArray(); + } + + @VisibleForTesting + static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) { + return ("client_presence::{" + accountIdentifier + ":" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + } + + private static AccountAndDeviceIdentifier parseClientPresenceKey(final byte[] clientPresenceKeyBytes) { + final String clientPresenceKey = new String(clientPresenceKeyBytes, StandardCharsets.UTF_8); + final int uuidStart = "client_presence::{".length(); + + final UUID accountIdentifier = UUID.fromString(clientPresenceKey.substring(uuidStart, uuidStart + 36)); + final byte deviceId = Byte.parseByte(clientPresenceKey.substring(uuidStart + 37, clientPresenceKey.length() - 1)); + + return new AccountAndDeviceIdentifier(accountIdentifier, deviceId); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java index fe795d11a..e86eb99f6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java @@ -11,6 +11,7 @@ import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.scheduler.Scheduler; +import java.util.function.Consumer; public class FaultTolerantPubSubClusterConnection extends AbstractFaultTolerantPubSubConnection> { @@ -32,7 +33,7 @@ public class FaultTolerantPubSubClusterConnection extends AbstractFaultTol this.topologyChangedEventScheduler = topologyChangedEventScheduler; } - public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) { + public void subscribeToClusterTopologyChangedEvents(final Consumer eventHandler) { usePubSubConnection(connection -> connection.getResources().eventBus().get() .filter(event -> { @@ -53,7 +54,7 @@ public class FaultTolerantPubSubClusterConnection extends AbstractFaultTol resubscribeRetry.executeRunnable(() -> { try { - eventHandler.run(); + eventHandler.accept((ClusterTopologyChangedEvent) event); } catch (final RuntimeException e) { logger.warn("Resubscribe for {} failed", getName(), e); throw e; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java index 7f5248336..5053c7c6b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterClient.java @@ -202,4 +202,11 @@ public class FaultTolerantRedisClusterClient { Schedulers.newSingle(name + "-redisPubSubEvents", true)); } + public FaultTolerantPubSubClusterConnection createBinaryPubSubConnection() { + final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(ByteArrayCodec.INSTANCE); + pubSubConnections.add(pubSubConnection); + + return new FaultTolerantPubSubClusterConnection<>(name, pubSubConnection, topologyChangedEventRetry, + Schedulers.newSingle(name + "-redisPubSubEvents", true)); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 4f402c531..989dcd04b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; @@ -126,6 +127,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final SecureStorageClient secureStorageClient; private final SecureValueRecovery2Client secureValueRecovery2Client; private final ClientPresenceManager clientPresenceManager; + private final PubSubClientEventManager pubSubClientEventManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final ClientPublicKeysManager clientPublicKeysManager; private final Executor accountLockExecutor; @@ -205,6 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen final SecureStorageClient secureStorageClient, final SecureValueRecovery2Client secureValueRecovery2Client, final ClientPresenceManager clientPresenceManager, + final PubSubClientEventManager pubSubClientEventManager, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final ClientPublicKeysManager clientPublicKeysManager, final Executor accountLockExecutor, @@ -223,6 +226,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen this.secureStorageClient = secureStorageClient; this.secureValueRecovery2Client = secureValueRecovery2Client; this.clientPresenceManager = clientPresenceManager; + this.pubSubClientEventManager = pubSubClientEventManager; this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager); this.clientPublicKeysManager = clientPublicKeysManager; this.accountLockExecutor = accountLockExecutor; @@ -329,7 +333,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen keysManager.deleteSingleUsePreKeys(pni), messagesManager.clear(aci), profilesManager.deleteAll(aci)) - .thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor) + .thenRunAsync(() -> { + clientPresenceManager.disconnectAllPresencesForUuid(aci); + pubSubClientEventManager.requestDisconnection(aci); + }, clientPresenceExecutor) .thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems)) .thenCompose(ignored -> { // We should have cleared all messages before overwriting the old account, but more may have arrived @@ -594,6 +601,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen .whenCompleteAsync((ignored, throwable) -> { if (throwable == null) { RedisOperation.unchecked(() -> clientPresenceManager.disconnectPresence(accountIdentifier, deviceId)); + pubSubClientEventManager.requestDisconnection(accountIdentifier, List.of(deviceId)); } }, clientPresenceExecutor); } @@ -1240,9 +1248,11 @@ public class AccountsManager extends RedisPubSubAdapter implemen registrationRecoveryPasswordsManager.removeForNumber(account.getNumber())) .thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems)) .thenCompose(ignored -> redisDeleteAsync(account)) - .thenRunAsync(() -> RedisOperation.unchecked(() -> - account.getDevices().forEach(device -> - clientPresenceManager.disconnectPresence(account.getUuid(), device.getId()))), clientPresenceExecutor); + .thenRunAsync(() -> { + RedisOperation.unchecked(() -> clientPresenceManager.disconnectAllPresencesForUuid(account.getUuid())); + + pubSubClientEventManager.requestDisconnection(account.getUuid()); + }, clientPresenceExecutor); } private String getAccountMapKey(String key) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 720cc3fb1..dbb2fd7bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -13,6 +13,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.micrometer.core.instrument.Counter; @@ -247,7 +248,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp pubSubConnection.usePubSubConnection(connection -> connection.sync().upstream().commands().unsubscribe()); } - private void resubscribeAll() { + private void resubscribeAll(final ClusterTopologyChangedEvent event) { final Set queueNames; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java index 188e1e4f4..f5fcd9a0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java @@ -6,6 +6,12 @@ package org.whispersystems.textsecuregcm.util; import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; public class RedisClusterUtil { @@ -38,4 +44,51 @@ public class RedisClusterUtil { public static String getMinimalHashTag(final int slot) { return HASHES_BY_SLOT[slot]; } + + /** + * Returns an array indicating which slots have moved as part of a {@link ClusterTopologyChangedEvent}. The elements + * of the array map to slots in the cluster; for example, if slot 1234 has changed, then element 1234 of the returned + * array will be {@code true}. + * + * @param clusterTopologyChangedEvent the event from which to derive an array of changed slots + * + * @return an array indicating which slots of changed + */ + public static boolean[] getChangedSlots(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) { + final Map beforeNodesById = clusterTopologyChangedEvent.before().stream() + .collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node)); + + final Map afterNodesById = clusterTopologyChangedEvent.after().stream() + .collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node)); + + final Set nodeIds = new HashSet<>(beforeNodesById.keySet()); + nodeIds.addAll(afterNodesById.keySet()); + + final boolean[] changedSlots = new boolean[SlotHash.SLOT_COUNT]; + + for (final String nodeId : nodeIds) { + if (beforeNodesById.containsKey(nodeId) && afterNodesById.containsKey(nodeId)) { + // This node was present before and after the topology change, but its slots may have changed + final boolean[] beforeSlots = new boolean[SlotHash.SLOT_COUNT]; + beforeNodesById.get(nodeId).getSlots().forEach(slot -> beforeSlots[slot] = true); + + final boolean[] afterSlots = new boolean[SlotHash.SLOT_COUNT]; + afterNodesById.get(nodeId).getSlots().forEach(slot -> afterSlots[slot] = true); + + for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) { + changedSlots[slot] |= beforeSlots[slot] ^ afterSlots[slot]; + } + } else if (beforeNodesById.containsKey(nodeId)) { + // The node was present before the topology change, but is gone now; all of its slots should be considered + // changed + beforeNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true); + } else { + // The node was present after the change, but wasn't there before; all of its slots should be considered + // changed + afterNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true); + } + } + + return changedSlots; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 68375d341..6d7519f2c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import io.micrometer.core.instrument.Tags; +import java.util.UUID; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -19,6 +20,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -47,6 +49,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final PushNotificationManager pushNotificationManager; private final PushNotificationScheduler pushNotificationScheduler; private final ClientPresenceManager clientPresenceManager; + private final PubSubClientEventManager pubSubClientEventManager; private final ScheduledExecutorService scheduledExecutorService; private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; @@ -55,12 +58,15 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; + private transient UUID connectionId; + public AuthenticatedConnectListener(ReceiptSender receiptSender, MessagesManager messagesManager, MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, PushNotificationScheduler pushNotificationScheduler, ClientPresenceManager clientPresenceManager, + PubSubClientEventManager pubSubClientEventManager, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, ClientReleaseManager clientReleaseManager, @@ -71,6 +77,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.pushNotificationManager = pushNotificationManager; this.pushNotificationScheduler = pushNotificationScheduler; this.clientPresenceManager = clientPresenceManager; + this.pubSubClientEventManager = pubSubClientEventManager; this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; @@ -121,6 +128,12 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { // It's preferable to start sending push notifications as soon as possible. RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection)); + if (connectionId != null) { + pubSubClientEventManager.handleClientDisconnected(auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + connectionId); + } + // Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection // will not be notified and will not change its state, but that's okay because it has already closed and // attempts to deliver mesages via this connection will not succeed. @@ -147,6 +160,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { // Finally, we register this client's presence, which suppresses push notifications. We do this last because // receiving extra push notifications is generally preferable to missing out on a push notification. clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection); + pubSubClientEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), null) + .thenAccept(connectionId -> this.connectionId = connectionId); renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() -> clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index e21ace10d..b87281b94 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.push.ClientEventListener; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; @@ -63,15 +64,13 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; -public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener { +public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener, ClientEventListener { private static final DistributionSummary messageTime = Metrics.summary( name(MessageController.class, "messageDeliveryDuration")); private static final DistributionSummary primaryDeviceMessageTime = Metrics.summary( name(MessageController.class, "primaryDeviceMessageDeliveryDuration")); private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage")); - private static final Counter messageAvailableCounter = Metrics.counter( - name(WebSocketConnection.class, "messagesAvailable")); private static final Counter messagesPersistedCounter = Metrics.counter( name(WebSocketConnection.class, "messagesPersisted")); private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent")); @@ -91,6 +90,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac "sendMessages"); private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, "sendMessageError"); + private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable"); + + private static final String PRESENCE_MANAGER_TAG = "presenceManager"; private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; private static final String ERROR_TYPE_TAG = "errorType"; @@ -468,7 +470,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac return false; } - messageAvailableCounter.increment(); + Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME, + PRESENCE_MANAGER_TAG, "legacy") + .increment(); storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); @@ -477,6 +481,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac return true; } + @Override + public void handleNewMessageAvailable() { + Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME, + PRESENCE_MANAGER_TAG, "pubsub") + .increment(); + } + @Override public boolean handleMessagesPersisted() { if (!client.isOpen()) { @@ -497,7 +508,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac public void handleDisplacement(final boolean connectedElsewhere) { final Tags tags = Tags.of( UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)) + Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)), + Tag.of(PRESENCE_MANAGER_TAG, "legacy") ); Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); @@ -522,6 +534,17 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } } + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + final Tags tags = Tags.of( + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)), + Tag.of(PRESENCE_MANAGER_TAG, "pubsub") + ); + + Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); + } + private record StoredMessageInfo(UUID guid, long serverTimestamp) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 5db958f84..c86d68d17 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -31,10 +31,12 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.push.APNSender; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.FcmSender; @@ -141,6 +143,8 @@ record CommandDependencies( .maxThreads(1).minThreads(1).build(); ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d")) .maxThreads(16).minThreads(16).build(); + ExecutorService clientEventExecutor = environment.lifecycle() + .virtualExecutorService(name(name, "clientEvent-%d")); ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build(); @@ -214,6 +218,9 @@ record CommandDependencies( storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration()); ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor, keyspaceNotificationDispatchExecutor); + ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager( + dynamicConfigurationManager); + PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager); MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); @@ -230,7 +237,7 @@ record CommandDependencies( new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, pubsubClient, accountLockManager, keys, messagesManager, profilesManager, - secureStorageClient, secureValueRecovery2Client, clientPresenceManager, + secureStorageClient, secureValueRecovery2Client, clientPresenceManager, pubSubClientEventManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), @@ -269,6 +276,7 @@ record CommandDependencies( environment.lifecycle().manage(apnSender); environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(clientPresenceManager); + environment.lifecycle().manage(pubSubClientEventManager); environment.lifecycle().manage(new ManagedAwsCrt()); return new CommandDependencies( diff --git a/service/src/main/proto/ClientPresence.proto b/service/src/main/proto/ClientPresence.proto new file mode 100644 index 000000000..b2141009d --- /dev/null +++ b/service/src/main/proto/ClientPresence.proto @@ -0,0 +1,38 @@ +/** + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +syntax = "proto3"; + +package org.signal.chat.presence; + +option java_package = "org.whispersystems.textsecuregcm.push"; +option java_multiple_files = true; + +message ClientEvent { + oneof event { + NewMessageAvailableEvent new_message_available = 1; + ClientConnectedEvent client_connected = 2; + DisconnectRequested disconnect_requested = 3; + } +} + +/** + * Indicates that a new message is available for the client to retrieve. + */ +message NewMessageAvailableEvent { +} + +/** + * Indicates that a client has connected to the presence system. + */ +message ClientConnectedEvent { + bytes connection_id = 1; +} + +/** + * Indicates that the server has requested that the client disconnect due to + * (for example) account lifecycle events. + */ +message DisconnectRequested { +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java index 166e86438..5d5eccaab 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.auth; -import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -61,6 +60,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -97,6 +97,7 @@ class LinkedDeviceRefreshRequirementProviderTest { private AccountsManager accountsManager; private ClientPresenceManager clientPresenceManager; + private PubSubClientEventManager pubSubClientEventManager; private LinkedDeviceRefreshRequirementProvider provider; @@ -104,11 +105,12 @@ class LinkedDeviceRefreshRequirementProviderTest { void setup() { accountsManager = mock(AccountsManager.class); clientPresenceManager = mock(ClientPresenceManager.class); + pubSubClientEventManager = mock(PubSubClientEventManager.class); provider = new LinkedDeviceRefreshRequirementProvider(accountsManager); final WebsocketRefreshRequestEventListener listener = - new WebsocketRefreshRequestEventListener(clientPresenceManager, provider); + new WebsocketRefreshRequestEventListener(clientPresenceManager, pubSubClientEventManager, provider); when(applicationEventListener.onRequest(any())).thenReturn(listener); @@ -146,6 +148,10 @@ class LinkedDeviceRefreshRequirementProviderTest { verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1); verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2); verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3); + + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 1)); + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 2)); + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 3)); } @ParameterizedTest @@ -173,8 +179,10 @@ class LinkedDeviceRefreshRequirementProviderTest { assertEquals(200, response.getStatus()); - initialDeviceIds.forEach(deviceId -> - verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId)); + initialDeviceIds.forEach(deviceId -> { + verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId); + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(deviceId)); + }); verifyNoMoreInteractions(clientPresenceManager); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java index 9911bd48c..0bb77228e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.net.URI; import java.util.Collections; import java.util.EnumSet; +import java.util.List; import java.util.Optional; import java.util.UUID; import javax.servlet.DispatcherType; @@ -47,6 +48,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -74,6 +76,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest { private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class); private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class); + private static final PubSubClientEventManager PUBSUB_CLIENT_PRESENCE = mock(PubSubClientEventManager.class); private WebSocketClient client; private final Account account1 = new Account(); @@ -122,9 +125,9 @@ class PhoneNumberChangeRefreshRequirementProviderTest { .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.jersey().register(new RemoteAddressFilter()); webSocketEnvironment.jersey() - .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); + .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE)); environment.jersey() - .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); + .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE)); webSocketEnvironment.setConnectListener(webSocketSessionContext -> { }); @@ -215,6 +218,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest { verify(CLIENT_PRESENCE, timeout(5000)) .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); verifyNoMoreInteractions(CLIENT_PRESENCE); + + verify(PUBSUB_CLIENT_PRESENCE, timeout(5000)) + .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId())); + verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE); } @Test @@ -231,6 +238,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest { verify(CLIENT_PRESENCE, timeout(5000)) .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); verifyNoMoreInteractions(CLIENT_PRESENCE); + + verify(PUBSUB_CLIENT_PRESENCE, timeout(5000)) + .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId())); + verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java index a6d0c0949..d922e526b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java @@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -47,6 +48,7 @@ class RegistrationLockVerificationManagerTest { private final AccountsManager accountsManager = mock(AccountsManager.class); private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); + private final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class); private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock( ExternalServiceCredentialsGenerator.class); private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock( @@ -56,7 +58,7 @@ class RegistrationLockVerificationManagerTest { private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( - accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator, + accountsManager, clientPresenceManager, pubSubClientEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator, registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters); private final RateLimiter pinLimiter = mock(RateLimiter.class); @@ -108,6 +110,7 @@ class RegistrationLockVerificationManagerTest { verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); } verify(clientPresenceManager).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(Device.PRIMARY_ID)); try { verify(pushNotificationManager).sendAttemptLoginNotification(any(), eq("failedRegistrationLock")); } catch (NotPushRegisteredException npre) {} @@ -131,6 +134,7 @@ class RegistrationLockVerificationManagerTest { } catch (NotPushRegisteredException npre) {} verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); + verify(pubSubClientEventManager, never()).requestDisconnection(any(), any()); }); } }; @@ -169,6 +173,7 @@ class RegistrationLockVerificationManagerTest { verify(account, never()).lockAuthTokenHash(); verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); + verify(pubSubClientEventManager, never()).requestDisconnection(any(), any()); } static Stream testSuccess() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index d54a41651..9078b14df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; @@ -111,6 +112,7 @@ class DeviceControllerTest { private static final Account maxedAccount = mock(Account.class); private static final Device primaryDevice = mock(Device.class); private static final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); + private static final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class); private static final Map deviceConfiguration = new HashMap<>(); private static final TestClock testClock = TestClock.now(); @@ -131,7 +133,8 @@ class DeviceControllerTest { .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .addProvider(new RateLimitExceededExceptionMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)) + .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager, + pubSubClientEventManager)) .addProvider(new DeviceLimitExceededExceptionMapper()) .addResource(deviceController) .build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 257ffd7bc..344505e47 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -39,6 +40,7 @@ class MessageSenderTest { private MessageProtos.Envelope message; private ClientPresenceManager clientPresenceManager; + private PubSubClientEventManager pubSubClientEventManager; private MessagesManager messagesManager; private PushNotificationManager pushNotificationManager; private MessageSender messageSender; @@ -54,9 +56,14 @@ class MessageSenderTest { message = generateRandomMessage(); clientPresenceManager = mock(ClientPresenceManager.class); + pubSubClientEventManager = mock(PubSubClientEventManager.class); messagesManager = mock(MessagesManager.class); pushNotificationManager = mock(PushNotificationManager.class); - messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager); + + when(pubSubClientEventManager.handleNewMessageAvailable(any(), anyByte())) + .thenReturn(CompletableFuture.completedFuture(true)); + + messageSender = new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager); when(account.getUuid()).thenReturn(ACCOUNT_UUID); when(device.getId()).thenReturn(DEVICE_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java new file mode 100644 index 000000000..38a04aa83 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PubSubClientEventManagerTest.java @@ -0,0 +1,337 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.push; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands; +import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; + +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class PubSubClientEventManagerTest { + + private PubSubClientEventManager localPresenceManager; + private PubSubClientEventManager remotePresenceManager; + + private static ExecutorService clientEventExecutor; + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + private static class ClientEventAdapter implements ClientEventListener { + + @Override + public void handleNewMessageAvailable() { + } + + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + } + } + + @BeforeAll + static void setUpBeforeAll() { + clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); + } + + @BeforeEach + void setUp() { + final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true); + + localPresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager); + remotePresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager); + + localPresenceManager.start(); + remotePresenceManager.start(); + } + + @AfterEach + void tearDown() { + localPresenceManager.stop(); + remotePresenceManager.stop(); + } + + @AfterAll + static void tearDownAfterAll() { + clientEventExecutor.shutdown(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void handleClientConnected(final boolean displaceRemotely) throws InterruptedException { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false); + final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false); + + final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false); + + localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() { + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + synchronized (firstListenerDisplaced) { + firstListenerDisplaced.set(true); + firstListenerConnectedElsewhere.set(connectedElsewhere); + + firstListenerDisplaced.notifyAll(); + } + } + }).toCompletableFuture().join(); + + assertFalse(firstListenerDisplaced.get()); + assertFalse(secondListenerDisplaced.get()); + + final PubSubClientEventManager displacingManager = + displaceRemotely ? remotePresenceManager : localPresenceManager; + + displacingManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() { + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + secondListenerDisplaced.set(true); + } + }).toCompletableFuture().join(); + + synchronized (firstListenerDisplaced) { + while (!firstListenerDisplaced.get()) { + firstListenerDisplaced.wait(); + } + } + + assertTrue(firstListenerDisplaced.get()); + assertFalse(secondListenerDisplaced.get()); + + assertTrue(firstListenerConnectedElsewhere.get()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void handleNewMessageAvailable(final boolean messageAvailableRemotely) throws InterruptedException { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final AtomicBoolean messageReceived = new AtomicBoolean(false); + + localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() { + @Override + public void handleNewMessageAvailable() { + synchronized (messageReceived) { + messageReceived.set(true); + messageReceived.notifyAll(); + } + } + }).toCompletableFuture().join(); + + final PubSubClientEventManager messagePresenceManager = + messageAvailableRemotely ? remotePresenceManager : localPresenceManager; + + assertTrue(messagePresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join()); + + synchronized (messageReceived) { + while (!messageReceived.get()) { + messageReceived.wait(); + } + } + + assertTrue(messageReceived.get()); + } + + @Test + void handleClientDisconnected() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final UUID connectionId = + localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter()) + .toCompletableFuture().join(); + + assertTrue(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join()); + + localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId).toCompletableFuture().join(); + + assertFalse(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join()); + } + + @Test + void isLocallyPresent() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + + final UUID connectionId = + localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter()) + .toCompletableFuture() + .join(); + + assertTrue(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + + localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId) + .toCompletableFuture() + .join(); + + assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException { + final UUID accountIdentifier = UUID.randomUUID(); + final byte firstDeviceId = Device.PRIMARY_ID; + final byte secondDeviceId = firstDeviceId + 1; + + final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false); + final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false); + + final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false); + + localPresenceManager.handleClientConnected(accountIdentifier, firstDeviceId, new ClientEventAdapter() { + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + synchronized (firstListenerDisplaced) { + firstListenerDisplaced.set(true); + firstListenerConnectedElsewhere.set(connectedElsewhere); + + firstListenerDisplaced.notifyAll(); + } + } + }).toCompletableFuture().join(); + + localPresenceManager.handleClientConnected(accountIdentifier, secondDeviceId, new ClientEventAdapter() { + @Override + public void handleConnectionDisplaced(final boolean connectedElsewhere) { + synchronized (secondListenerDisplaced) { + secondListenerDisplaced.set(true); + secondListenerDisplaced.notifyAll(); + } + } + }).toCompletableFuture().join(); + + assertFalse(firstListenerDisplaced.get()); + assertFalse(secondListenerDisplaced.get()); + + final PubSubClientEventManager displacingManager = + requestDisconnectionRemotely ? remotePresenceManager : localPresenceManager; + + displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join(); + + synchronized (firstListenerDisplaced) { + while (!firstListenerDisplaced.get()) { + firstListenerDisplaced.wait(); + } + } + + assertTrue(firstListenerDisplaced.get()); + assertFalse(secondListenerDisplaced.get()); + + assertFalse(firstListenerConnectedElsewhere.get()); + } + + @Test + void resubscribe() { + final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true); + + @SuppressWarnings("unchecked") final RedisClusterPubSubCommands pubSubCommands = + mock(RedisClusterPubSubCommands.class); + + @SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands pubSubAsyncCommands = + mock(RedisClusterPubSubAsyncCommands.class); + + when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null)); + + final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder() + .binaryPubSubCommands(pubSubCommands) + .binaryPubSubAsyncCommands(pubSubAsyncCommands) + .build(); + + final PubSubClientEventManager presenceManager = + new PubSubClientEventManager(clusterClient, Runnable::run, experimentEnrollmentManager); + + presenceManager.start(); + + final UUID firstAccountIdentifier = UUID.randomUUID(); + final byte firstDeviceId = Device.PRIMARY_ID; + final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId)); + + final UUID secondAccountIdentifier; + final byte secondDeviceId = firstDeviceId + 1; + + // Make sure that the two subscriptions wind up in different slots + { + UUID candidateIdentifier; + + do { + candidateIdentifier = UUID.randomUUID(); + } while (SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(candidateIdentifier, secondDeviceId)) == firstSlot); + + secondAccountIdentifier = candidateIdentifier; + } + + presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join(); + presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join(); + + final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId)); + + final String firstNodeId = UUID.randomUUID().toString(); + + final RedisClusterNode firstBeforeNode = mock(RedisClusterNode.class); + when(firstBeforeNode.getNodeId()).thenReturn(firstNodeId); + when(firstBeforeNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT).boxed().toList()); + + final RedisClusterNode firstAfterNode = mock(RedisClusterNode.class); + when(firstAfterNode.getNodeId()).thenReturn(firstNodeId); + when(firstAfterNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT) + .filter(slot -> slot != secondSlot) + .boxed() + .toList()); + + final RedisClusterNode secondAfterNode = mock(RedisClusterNode.class); + when(secondAfterNode.getNodeId()).thenReturn(UUID.randomUUID().toString()); + when(secondAfterNode.getSlots()).thenReturn(List.of(secondSlot)); + + presenceManager.resubscribe(new ClusterTopologyChangedEvent( + List.of(firstBeforeNode), + List.of(firstAfterNode, secondAfterNode))); + + verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId)); + verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java index cd5e60779..72c54618b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; @@ -46,7 +47,7 @@ class FaultTolerantPubSubClusterConnectionTest { private TestPublisher eventPublisher; - private Runnable resubscribe; + private Consumer resubscribe; private AtomicInteger resubscribeCounter; private CountDownLatch resubscribeFailure; @@ -93,7 +94,7 @@ class FaultTolerantPubSubClusterConnectionTest { resubscribeCounter = new AtomicInteger(); - resubscribe = () -> { + resubscribe = event -> { try { resubscribeCounter.incrementAndGet(); pubSubConnection.sync().nodes((ignored) -> true); @@ -137,7 +138,7 @@ class FaultTolerantPubSubClusterConnectionTest { void testFilterClusterTopologyChangeEvents() throws InterruptedException { final CountDownLatch topologyEventLatch = new CountDownLatch(1); - faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(topologyEventLatch::countDown); + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(event -> topologyEventLatch.countDown()); final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index cbf444092..eaee1f505 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -152,6 +153,7 @@ public class AccountCreationDeletionIntegrationTest { secureStorageClient, svr2Client, mock(ClientPresenceManager.class), + mock(PubSubClientEventManager.class), registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index a9fcfc94a..274903c7b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -37,6 +37,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -67,6 +68,7 @@ class AccountsManagerChangeNumberIntegrationTest { private KeysManager keysManager; private ClientPresenceManager clientPresenceManager; + private PubSubClientEventManager pubSubClientEventManager; private ExecutorService accountLockExecutor; private ExecutorService clientPresenceExecutor; @@ -119,6 +121,7 @@ class AccountsManagerChangeNumberIntegrationTest { when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null)); clientPresenceManager = mock(ClientPresenceManager.class); + pubSubClientEventManager = mock(PubSubClientEventManager.class); final PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName()); @@ -147,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest { secureStorageClient, svr2Client, clientPresenceManager, + pubSubClientEventManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, @@ -281,7 +285,8 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow()); - verify(clientPresenceManager).disconnectPresence(existingAccountUuid, Device.PRIMARY_ID); + verify(clientPresenceManager).disconnectAllPresencesForUuid(existingAccountUuid); + verify(pubSubClientEventManager).requestDisconnection(existingAccountUuid); assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalNumber)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondNumber)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index a8d15bce6..1ba0a8e8b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -134,6 +135,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { mock(SecureStorageClient.class), mock(SecureValueRecovery2Client.class), mock(ClientPresenceManager.class), + mock(PubSubClientEventManager.class), mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), mock(Executor.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java index 3dfe355cf..228ebef2e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java @@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -63,6 +64,7 @@ public class AccountsManagerDeviceTransferIntegrationTest { mock(SecureStorageClient.class), mock(SecureValueRecovery2Client.class), mock(ClientPresenceManager.class), + mock(PubSubClientEventManager.class), mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), mock(ExecutorService.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index d92233584..f435d35a0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -118,6 +119,7 @@ class AccountsManagerTest { private MessagesManager messagesManager; private ProfilesManager profilesManager; private ClientPresenceManager clientPresenceManager; + private PubSubClientEventManager pubSubClientEventManager; private ClientPublicKeysManager clientPublicKeysManager; private Map phoneNumberIdentifiersByE164; @@ -153,6 +155,7 @@ class AccountsManagerTest { messagesManager = mock(MessagesManager.class); profilesManager = mock(ProfilesManager.class); clientPresenceManager = mock(ClientPresenceManager.class); + pubSubClientEventManager = mock(PubSubClientEventManager.class); clientPublicKeysManager = mock(ClientPublicKeysManager.class); dynamicConfiguration = mock(DynamicConfiguration.class); @@ -259,6 +262,7 @@ class AccountsManagerTest { storageClient, svr2Client, clientPresenceManager, + pubSubClientEventManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, mock(Executor.class), @@ -799,6 +803,7 @@ class AccountsManagerTest { verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId()); verify(clientPublicKeysManager).buildTransactWriteItemForDeletion(account.getUuid(), linkedDevice.getId()); verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId()); + verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(linkedDevice.getId())); } @Test @@ -817,6 +822,7 @@ class AccountsManagerTest { verify(messagesManager, never()).clear(any(), anyByte()); verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte()); verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte()); + verify(pubSubClientEventManager, never()).requestDisconnection(any(), any()); } @Test @@ -886,6 +892,7 @@ class AccountsManagerTest { verify(messagesManager, times(2)).clear(existingUuid); verify(profilesManager, times(2)).deleteAll(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); + verify(pubSubClientEventManager).requestDisconnection(existingUuid); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 918a9c275..ebf9f506d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -146,6 +147,7 @@ class AccountsManagerUsernameIntegrationTest { mock(SecureStorageClient.class), mock(SecureValueRecovery2Client.class), mock(ClientPresenceManager.class), + mock(PubSubClientEventManager.class), mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), Executors.newSingleThreadExecutor(), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index e05cfcaa0..27dbcd548 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -152,6 +153,7 @@ public class AddRemoveDeviceIntegrationTest { secureStorageClient, svr2Client, mock(ClientPresenceManager.class), + mock(PubSubClientEventManager.class), registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java index b6b7ac306..86343e299 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java @@ -14,8 +14,12 @@ import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands; +import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; import java.util.function.Consumer; import java.util.function.Function; +import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; public class RedisClusterHelper { @@ -30,7 +34,12 @@ public class RedisClusterHelper { final RedisAdvancedClusterAsyncCommands stringAsyncCommands, final RedisAdvancedClusterCommands binaryCommands, final RedisAdvancedClusterAsyncCommands binaryAsyncCommands, - final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) { + final RedisAdvancedClusterReactiveCommands binaryReactiveCommands, + final RedisClusterPubSubCommands stringPubSubCommands, + final RedisClusterPubSubAsyncCommands stringAsyncPubSubCommands, + final RedisClusterPubSubCommands binaryPubSubCommands, + final RedisClusterPubSubAsyncCommands binaryAsyncPubSubCommands) { + final FaultTolerantRedisClusterClient cluster = mock(FaultTolerantRedisClusterClient.class); final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class); final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class); @@ -59,6 +68,45 @@ public class RedisClusterHelper { return null; }).when(cluster).useBinaryCluster(any(Consumer.class)); + final StatefulRedisClusterPubSubConnection stringPubSubConnection = + mock(StatefulRedisClusterPubSubConnection.class); + + final StatefulRedisClusterPubSubConnection binaryPubSubConnection = + mock(StatefulRedisClusterPubSubConnection.class); + + final FaultTolerantPubSubClusterConnection faultTolerantPubSubClusterConnection = + mock(FaultTolerantPubSubClusterConnection.class); + + final FaultTolerantPubSubClusterConnection faultTolerantBinaryPubSubClusterConnection = + mock(FaultTolerantPubSubClusterConnection.class); + + when(stringPubSubConnection.sync()).thenReturn(stringPubSubCommands); + when(stringPubSubConnection.async()).thenReturn(stringAsyncPubSubCommands); + when(binaryPubSubConnection.sync()).thenReturn(binaryPubSubCommands); + when(binaryPubSubConnection.async()).thenReturn(binaryAsyncPubSubCommands); + + when(cluster.createPubSubConnection()).thenReturn(faultTolerantPubSubClusterConnection); + when(cluster.createBinaryPubSubConnection()).thenReturn(faultTolerantBinaryPubSubClusterConnection); + + when(faultTolerantPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(stringPubSubConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(stringPubSubConnection); + return null; + }).when(faultTolerantPubSubClusterConnection).usePubSubConnection(any(Consumer.class)); + + when(faultTolerantBinaryPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer( + invocation -> { + return invocation.getArgument(0, Function.class).apply(binaryPubSubConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(binaryPubSubConnection); + return null; + }).when(faultTolerantBinaryPubSubClusterConnection).usePubSubConnection(any(Consumer.class)); + return cluster; } @@ -77,6 +125,18 @@ public class RedisClusterHelper { private RedisAdvancedClusterReactiveCommands binaryReactiveCommands = mock(RedisAdvancedClusterReactiveCommands.class); + private RedisClusterPubSubCommands stringPubSubCommands = + mock(RedisClusterPubSubCommands.class); + + private RedisClusterPubSubCommands binaryPubSubCommands = + mock(RedisClusterPubSubCommands.class); + + private RedisClusterPubSubAsyncCommands stringPubSubAsyncCommands = + mock(RedisClusterPubSubAsyncCommands.class); + + private RedisClusterPubSubAsyncCommands binaryPubSubAsyncCommands = + mock(RedisClusterPubSubAsyncCommands.class); + private Builder() { } @@ -107,9 +167,33 @@ public class RedisClusterHelper { return this; } + public Builder stringPubSubCommands(final RedisClusterPubSubCommands stringPubSubCommands) { + this.stringPubSubCommands = stringPubSubCommands; + return this; + } + + public Builder binaryPubSubCommands(final RedisClusterPubSubCommands binaryPubSubCommands) { + this.binaryPubSubCommands = binaryPubSubCommands; + return this; + } + + public Builder stringPubSubAsyncCommands( + final RedisClusterPubSubAsyncCommands stringPubSubAsyncCommands) { + this.stringPubSubAsyncCommands = stringPubSubAsyncCommands; + return this; + } + + public Builder binaryPubSubAsyncCommands( + final RedisClusterPubSubAsyncCommands binaryPubSubAsyncCommands) { + this.binaryPubSubAsyncCommands = binaryPubSubAsyncCommands; + return this; + } + public FaultTolerantRedisClusterClient build() { - return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands, binaryAsyncCommands, - binaryReactiveCommands); + return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands, + binaryAsyncCommands, + binaryReactiveCommands, stringPubSubCommands, stringPubSubAsyncCommands, binaryPubSubCommands, + binaryPubSubAsyncCommands); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java index 465d6aac1..260362705 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java @@ -5,17 +5,199 @@ package org.whispersystems.textsecuregcm.util; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; class RedisClusterUtilTest { - @Test - void testGetMinimalHashTag() { - for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) { - assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot))); - } + @Test + void testGetMinimalHashTag() { + for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) { + assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot))); } + } + + @ParameterizedTest + @MethodSource + void getChangedSlots(final ClusterTopologyChangedEvent event, final boolean[] expectedSlotsChanged) { + assertArrayEquals(expectedSlotsChanged, RedisClusterUtil.getChangedSlots(event)); + } + + private static List getChangedSlots() { + final List arguments = new ArrayList<>(); + + // Slot moved from one node to another + { + final String firstNodeId = UUID.randomUUID().toString(); + final String secondNodeId = UUID.randomUUID().toString(); + + final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class); + when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId); + when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192)); + + final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class); + when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId); + when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384)); + + final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class); + when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId); + when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191)); + + final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class); + when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId); + when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16384)); + + final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent( + List.of(firstNodeBefore, secondNodeBefore), + List.of(firstNodeAfter, secondNodeAfter)); + + final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT]; + slotsChanged[8191] = true; + + arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged)); + } + + // New node added to cluster + { + final String firstNodeId = UUID.randomUUID().toString(); + final String secondNodeId = UUID.randomUUID().toString(); + + final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class); + when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId); + when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192)); + + final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class); + when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId); + when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384)); + + final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class); + when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId); + when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192)); + + final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class); + when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId); + when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 12288)); + + final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class); + when(thirdNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString()); + when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(12288, 16384)); + + final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent( + List.of(firstNodeBefore, secondNodeBefore), + List.of(firstNodeAfter, secondNodeAfter, thirdNodeAfter)); + + final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT]; + + for (int slot = 12288; slot < 16384; slot++) { + slotsChanged[slot] = true; + } + + arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged)); + } + + // Node removed from cluster + { + final String firstNodeId = UUID.randomUUID().toString(); + final String secondNodeId = UUID.randomUUID().toString(); + + final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class); + when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId); + when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192)); + + final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class); + when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId); + when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 12288)); + + final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class); + when(thirdNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString()); + when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(12288, 16384)); + + final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class); + when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId); + when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192)); + + final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class); + when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId); + when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 16384)); + + final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent( + List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore), + List.of(firstNodeAfter, secondNodeAfter)); + + final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT]; + + for (int slot = 12288; slot < 16384; slot++) { + slotsChanged[slot] = true; + } + + arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged)); + } + + // Node added, node removed, and slot moved + // Node removed from cluster + { + final String secondNodeId = UUID.randomUUID().toString(); + final String thirdNodeId = UUID.randomUUID().toString(); + + final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class); + when(firstNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString()); + when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 1)); + + final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class); + when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId); + when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(1, 8192)); + + final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class); + when(thirdNodeBefore.getNodeId()).thenReturn(thirdNodeId); + when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384)); + + final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class); + when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId); + when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191)); + + final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class); + when(thirdNodeAfter.getNodeId()).thenReturn(thirdNodeId); + when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16383)); + + final RedisClusterNode fourthNodeAfter = mock(RedisClusterNode.class); + when(fourthNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString()); + when(fourthNodeAfter.getSlots()).thenReturn(getSlotRange(16383, 16384)); + + final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent( + List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore), + List.of(secondNodeAfter, thirdNodeAfter, fourthNodeAfter)); + + final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT]; + slotsChanged[0] = true; + slotsChanged[8191] = true; + slotsChanged[16383] = true; + + arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged)); + } + + return arguments; + } + + private static List getSlotRange(final int startInclusive, final int endExclusive) { + final List slots = new ArrayList<>(endExclusive - startInclusive); + + for (int i = startInclusive; i < endExclusive; i++) { + slots.add(i); + } + + return slots; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 08c83f60e..5ce2d20bd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.PubSubClientEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -124,8 +125,8 @@ class WebSocketConnectionTest { new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager, - mock(MessageDeliveryLoopMonitor.class)); + mock(ClientPresenceManager.class), mock(PubSubClientEventManager.class), retrySchedulingExecutor, + messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))