diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d3e88bffb..d1d4beb9b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -102,6 +102,7 @@ import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.providers.RedisHealthCheck; import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.GCMSender; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -173,6 +174,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import static com.codahale.metrics.MetricRegistry.name; @@ -335,7 +337,10 @@ public class WhisperServerService extends Application(1_000)).build(); + ScheduledExecutorService clientPresenceExecutor = environment.lifecycle().scheduledExecutorService("clientPresenceManager").threads(1).build(); + ExecutorService messageCacheClusterExperimentExecutor = environment.lifecycle().executorService("messages_cache_experiment").maxThreads(8).workQueue(new ArrayBlockingQueue<>(1_000)).build(); + ExecutorService websocketExperimentExecutor = environment.lifecycle().executorService("websocketPresenceExperiment").maxThreads(8).workQueue(new ArrayBlockingQueue<>(1_000)).build(); + ClientPresenceManager clientPresenceManager = new ClientPresenceManager(messagesCacheCluster, clientPresenceExecutor); DirectoryManager directory = new DirectoryManager(directoryClient); DirectoryQueue directoryQueue = new DirectoryQueue(config.getDirectoryConfiguration().getSqsConfiguration()); @@ -354,7 +359,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); - webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(pushSender, receiptSender, messagesManager, pubSubManager, apnFallbackManager)); + webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(pushSender, receiptSender, messagesManager, pubSubManager, apnFallbackManager, clientPresenceManager)); webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new KeepAliveController(pubSubManager)); webSocketEnvironment.jersey().register(messageController); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java new file mode 100644 index 000000000..5b9edcf1f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java @@ -0,0 +1,263 @@ +package org.whispersystems.textsecuregcm.push; + +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; +import com.codahale.metrics.Timer; +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.lifecycle.Managed; +import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.util.Constants; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import static com.codahale.metrics.MetricRegistry.name; + +/** + * The client presence manager keeps track of which clients are actively connected and "present" to receive messages. + * Only one client per account/device may be present at a time; if a second client for the same account/device declares + * its presence, the previous client is displaced. + */ +public class ClientPresenceManager extends RedisClusterPubSubAdapter implements Managed { + + private final String managerId = UUID.randomUUID().toString(); + private final String connectedClientSetKey = getConnectedClientSetKey(managerId); + + private final FaultTolerantRedisCluster presenceCluster; + private final ClusterLuaScript clearPresenceScript; + + private final ScheduledExecutorService scheduledExecutorService; + private ScheduledFuture pruneMissingPeersFuture; + + private final Map displacementListenersByPresenceKey = new ConcurrentHashMap<>(); + + private final Timer checkPresenceTimer; + private final Timer setPresenceTimer; + private final Timer clearPresenceTimer; + private final Timer prunePeersTimer; + private final Meter pruneClientMeter; + private final Meter remoteDisplacementMeter; + + private static final int PRUNE_PEERS_INTERVAL_SECONDS = (int)Duration.ofMinutes(3).toSeconds(); + + static final String MANAGER_SET_KEY = "presence::managers"; + + private static final Logger log = LoggerFactory.getLogger(ClientPresenceManager.class); + + public ClientPresenceManager(final FaultTolerantRedisCluster presenceCluster, final ScheduledExecutorService scheduledExecutorService) throws IOException { + this.presenceCluster = presenceCluster; + this.scheduledExecutorService = scheduledExecutorService; + this.clearPresenceScript = ClusterLuaScript.fromResource(presenceCluster, "lua/clear_presence.lua", ScriptOutputType.INTEGER); + + final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + metricRegistry.gauge(name(getClass(), "localClientCount"), () -> displacementListenersByPresenceKey::size); + + this.checkPresenceTimer = metricRegistry.timer(name(getClass(), "checkPresence")); + this.setPresenceTimer = metricRegistry.timer(name(getClass(), "setPresence")); + this.clearPresenceTimer = metricRegistry.timer(name(getClass(), "clearPresence")); + this.prunePeersTimer = metricRegistry.timer(name(getClass(), "prunePeers")); + this.pruneClientMeter = metricRegistry.meter(name(getClass(), "pruneClient")); + this.remoteDisplacementMeter = metricRegistry.meter(name(getClass(), "remoteDisplacement")); + } + + @Override + public void start() { + presenceCluster.usePubSubConnection(connection -> { + final String configuredKeyspaceNotifications = connection.sync().configGet("notify-keyspace-events").getOrDefault("notify-keyspace-events", ""); + + for (final char requiredNotificationType : new char[] {'K', '$'}) { + if (configuredKeyspaceNotifications.indexOf(requiredNotificationType) == -1) { + throw new IllegalStateException("Required keyspace notification type not configured. Need at least K$, but is actually: " + configuredKeyspaceNotifications); + } + } + + connection.addListener(this); + connection.getResources().eventBus().get() + .filter(event -> event instanceof ClusterTopologyChangedEvent) + .handle((event, sink) -> { + resubscribeAll(); + sink.next(event); + }); + + final String presenceChannel = getManagerPresenceChannel(managerId); + final int slot = SlotHash.getSlot(presenceChannel); + + connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)).commands().subscribe(presenceChannel); + }); + + presenceCluster.useWriteCluster(connection -> connection.sync().sadd(MANAGER_SET_KEY, managerId)); + + pruneMissingPeersFuture = scheduledExecutorService.scheduleAtFixedRate(this::pruneMissingPeers, new Random().nextInt(PRUNE_PEERS_INTERVAL_SECONDS), PRUNE_PEERS_INTERVAL_SECONDS, TimeUnit.SECONDS); + } + + @Override + public void stop() { + presenceCluster.usePubSubConnection(connection -> connection.removeListener(this)); + + if (pruneMissingPeersFuture != null) { + pruneMissingPeersFuture.cancel(false); + } + + for (final String presenceKey : displacementListenersByPresenceKey.keySet()) { + clearPresence(presenceKey); + } + + presenceCluster.useWriteCluster(connection -> { + connection.sync().srem(MANAGER_SET_KEY, managerId); + connection.sync().del(getConnectedClientSetKey(managerId)); + }); + + presenceCluster.usePubSubConnection(connection -> connection.sync().masters().commands().unsubscribe(getManagerPresenceChannel(managerId))); + } + + public void setPresent(final UUID accountUuid, final long deviceId, final DisplacedPresenceListener displacementListener) { + try (final Timer.Context ignored = setPresenceTimer.time()) { + final String presenceKey = getPresenceKey(accountUuid, deviceId); + + displacePresence(presenceKey); + + displacementListenersByPresenceKey.put(presenceKey, displacementListener); + + presenceCluster.useWriteCluster(connection -> { + final RedisAdvancedClusterCommands commands = connection.sync(); + + commands.set(presenceKey, managerId); + commands.sadd(connectedClientSetKey, presenceKey); + }); + + subscribeForRemotePresenceChanges(presenceKey); + } + } + + private void displacePresence(final String presenceKey) { + final DisplacedPresenceListener displacementListener = displacementListenersByPresenceKey.get(presenceKey); + + if (displacementListener != null) { + displacementListener.handleDisplacement(); + } + + clearPresence(presenceKey); + } + + public boolean isPresent(final UUID accountUuid, final long deviceId) { + try (final Timer.Context ignored = checkPresenceTimer.time()) { + return presenceCluster.withReadCluster(connection -> connection.sync().exists(getPresenceKey(accountUuid, deviceId))) == 1; + } + } + + public boolean clearPresence(final UUID accountUuid, final long deviceId) { + return clearPresence(getPresenceKey(accountUuid, deviceId)); + } + + private boolean clearPresence(final String presenceKey) { + try (final Timer.Context ignored = clearPresenceTimer.time()) { + displacementListenersByPresenceKey.remove(presenceKey); + unsubscribeFromRemotePresenceChanges(presenceKey); + + final boolean removed = clearPresenceScript.execute(List.of(presenceKey), List.of(managerId)) != null; + presenceCluster.useWriteCluster(connection -> connection.sync().srem(connectedClientSetKey, presenceKey)); + + return removed; + } + } + + private void subscribeForRemotePresenceChanges(final String presenceKey) { + final int slot = SlotHash.getSlot(presenceKey); + + presenceCluster.usePubSubConnection(connection -> connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)) + .commands() + .subscribe(getKeyspaceNotificationChannel(presenceKey))); + } + + private void resubscribeAll() { + for (final String presenceKey : displacementListenersByPresenceKey.keySet()) { + subscribeForRemotePresenceChanges(presenceKey); + } + } + + private void unsubscribeFromRemotePresenceChanges(final String presenceKey) { + presenceCluster.usePubSubConnection(connection -> connection.async().masters().commands().unsubscribe(getKeyspaceNotificationChannel(presenceKey))); + } + + void pruneMissingPeers() { + try (final Timer.Context ignored = prunePeersTimer.time()) { + final Set peerIds = presenceCluster.withReadCluster(connection -> connection.sync().smembers(MANAGER_SET_KEY)); + peerIds.remove(managerId); + + for (final String peerId : peerIds) { + final boolean peerMissing = presenceCluster.withWriteCluster(connection -> connection.sync().publish(getManagerPresenceChannel(peerId), "ping") == 0); + + if (peerMissing) { + log.debug("Presence manager {} did not respond to ping", peerId); + + final String connectedClientsKey = getConnectedClientSetKey(peerId); + + presenceCluster.useWriteCluster(connection -> { + final RedisAdvancedClusterCommands commands = connection.sync(); + + String presenceKey; + + while ((presenceKey = commands.spop(connectedClientsKey)) != null) { + clearPresenceScript.execute(List.of(presenceKey), List.of(peerId)); + pruneClientMeter.mark(); + } + + commands.del(connectedClientsKey); + commands.srem(MANAGER_SET_KEY, peerId); + }); + } + } + } + } + + @Override + public void message(final RedisClusterNode node, final String channel, final String message) { + if ("set".equals(message) && channel.startsWith("__keyspace@0__:presence::{")) { + // Another process has overwritten this presence key, which means the client has connected to another host. + // At this point, we're on a Lettuce IO thread and need to dispatch to a separate thread before making + // synchronous Lettuce calls to avoid deadlocking. + scheduledExecutorService.execute(() -> { + displacePresence(channel.substring("__keyspace@0__:".length())); + remoteDisplacementMeter.mark(); + }); + } + } + + @VisibleForTesting + static String getPresenceKey(final UUID accountUuid, final long deviceId) { + return "presence::{" + accountUuid.toString() + "::" + deviceId + "}"; + } + + private static String getKeyspaceNotificationChannel(final String presenceKey) { + return "__keyspace@0__:" + presenceKey; + } + + @VisibleForTesting + static String getConnectedClientSetKey(final String managerId) { + return "presence::clients::" + managerId; + } + + @VisibleForTesting + static String getManagerPresenceChannel(final String managerId) { + return "presence::manager::" + managerId; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/DisplacedPresenceListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/DisplacedPresenceListener.java new file mode 100644 index 000000000..ab72618b9 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/DisplacedPresenceListener.java @@ -0,0 +1,11 @@ +package org.whispersystems.textsecuregcm.push; + +/** + * A displaced presence listener is notified when a specific client's presence has been displaced because the same + * client opened a newer connection to the Signal service. + */ +@FunctionalInterface +public interface DisplacedPresenceListener { + + void handleDisplacement(); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index fd2ca1cf3..41c9f2044 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -22,6 +22,7 @@ import com.codahale.metrics.SharedMetricRegistries; import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -30,6 +31,8 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; +import java.util.concurrent.Executor; + import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; @@ -60,12 +63,18 @@ public class WebsocketSender { private final Meter provisioningOnlineMeter = metricRegistry.meter(name(getClass(), "provisioning_online" )); private final Meter provisioningOfflineMeter = metricRegistry.meter(name(getClass(), "provisioning_offline")); - private final MessagesManager messagesManager; - private final PubSubManager pubSubManager; + private final MessagesManager messagesManager; + private final PubSubManager pubSubManager; + private final ClientPresenceManager clientPresenceManager; - public WebsocketSender(MessagesManager messagesManager, PubSubManager pubSubManager) { - this.messagesManager = messagesManager; - this.pubSubManager = pubSubManager; + private final Experiment presenceExperiment = new Experiment("presence", "websocketSender"); + private final Executor experimentExecutor; + + public WebsocketSender(MessagesManager messagesManager, PubSubManager pubSubManager, ClientPresenceManager clientPresenceManager, Executor experimentExecutor) { + this.messagesManager = messagesManager; + this.pubSubManager = pubSubManager; + this.clientPresenceManager = clientPresenceManager; + this.experimentExecutor = experimentExecutor; } public DeliveryStatus sendMessage(Account account, Device device, Envelope message, Type channel, boolean online) { @@ -75,7 +84,11 @@ public class WebsocketSender { .setContent(message.toByteString()) .build(); - if (pubSubManager.publish(address, pubSubMessage)) { + final boolean clientPresent = pubSubManager.publish(address, pubSubMessage); + + presenceExperiment.compareSupplierResultAsync(clientPresent, () -> clientPresenceManager.isPresent(account.getUuid(), device.getId()), experimentExecutor); + + if (clientPresent) { if (channel == Type.APN) apnOnlineMeter.mark(); else if (channel == Type.GCM) gcmOnlineMeter.mark(); else websocketOnlineMeter.mark(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index bf0696649..888aec501 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -51,6 +51,8 @@ public class FaultTolerantRedisCluster { this.writeCircuitBreaker = CircuitBreaker.of(name + "-write", circuitBreakerConfiguration.toCircuitBreakerConfig()); this.pubSubCircuitBreaker = CircuitBreaker.of(name + "-pubsub", circuitBreakerConfiguration.toCircuitBreakerConfig()); + this.pubSubClusterConnection.setNodeMessagePropagation(true); + CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME), readCircuitBreaker, FaultTolerantRedisCluster.class); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index ecd6143ed..1332f3b8d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -1,6 +1,7 @@ package org.whispersystems.textsecuregcm.websocket; import com.codahale.metrics.Counter; +import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; @@ -8,6 +9,8 @@ import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisOperation; @@ -31,24 +34,28 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration" )); private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration")); private static final Counter openWebsocketCounter = metricRegistry.counter(name(WebSocketConnection.class, "open_websockets")); + private static final Meter explicitDisplacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); - private final PushSender pushSender; - private final ReceiptSender receiptSender; - private final MessagesManager messagesManager; - private final PubSubManager pubSubManager; - private final ApnFallbackManager apnFallbackManager; + private final PushSender pushSender; + private final ReceiptSender receiptSender; + private final MessagesManager messagesManager; + private final PubSubManager pubSubManager; + private final ApnFallbackManager apnFallbackManager; + private final ClientPresenceManager clientPresenceManager; public AuthenticatedConnectListener(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, PubSubManager pubSubManager, - ApnFallbackManager apnFallbackManager) + ApnFallbackManager apnFallbackManager, + ClientPresenceManager clientPresenceManager) { - this.pushSender = pushSender; - this.receiptSender = receiptSender; - this.messagesManager = messagesManager; - this.pubSubManager = pubSubManager; - this.apnFallbackManager = apnFallbackManager; + this.pushSender = pushSender; + this.receiptSender = receiptSender; + this.messagesManager = messagesManager; + this.pubSubManager = pubSubManager; + this.apnFallbackManager = apnFallbackManager; + this.clientPresenceManager = clientPresenceManager; } @Override @@ -68,6 +75,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.inc(); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); + clientPresenceManager.setPresent(account.getUuid(), device.getId(), explicitDisplacementMeter::mark); pubSubManager.publish(address, connectMessage); pubSubManager.subscribe(address, connection); @@ -76,6 +84,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { openWebsocketCounter.dec(); pubSubManager.unsubscribe(address, connection); + clientPresenceManager.clearPresence(account.getUuid(), device.getId()); timer.stop(); } }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 03369b9d6..24860dd03 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -40,9 +41,10 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements DispatchChannel { - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); - private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); + private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); + private static final Meter pubSubDisplacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubDisplacement")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -86,6 +88,7 @@ public class WebSocketConnection implements DispatchChannel { break; case PubSubMessage.Type.CONNECTED_VALUE: if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) { + pubSubDisplacementMeter.mark(); client.hardDisconnectQuietly(); } break; diff --git a/service/src/main/resources/lua/clear_presence.lua b/service/src/main/resources/lua/clear_presence.lua new file mode 100644 index 000000000..9e716c118 --- /dev/null +++ b/service/src/main/resources/lua/clear_presence.lua @@ -0,0 +1,9 @@ +local presenceKey = KEYS[1] +local presenceUuid = ARGV[1] + +if redis.call("GET", presenceKey) == presenceUuid then + redis.call("DEL", presenceKey) + return true +end + +return false diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java new file mode 100644 index 000000000..d72f4de4c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java @@ -0,0 +1,224 @@ +package org.whispersystems.textsecuregcm.push; + +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ClientPresenceManagerTest extends AbstractRedisClusterTest { + + private ScheduledExecutorService presenceRenewalExecutorService; + private ClientPresenceManager clientPresenceManager; + + private static final DisplacedPresenceListener NO_OP = () -> {}; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + getRedisCluster().useWriteCluster(connection -> { + connection.sync().flushall(); + connection.sync().masters().commands().configSet("notify-keyspace-events", "K$z"); + }); + + presenceRenewalExecutorService = Executors.newSingleThreadScheduledExecutor(); + clientPresenceManager = new ClientPresenceManager(getRedisCluster(), presenceRenewalExecutorService); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + + presenceRenewalExecutorService.shutdown(); + presenceRenewalExecutorService.awaitTermination(1, TimeUnit.MINUTES); + } + + @Test + public void testIsPresent() { + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + + assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId)); + + clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); + assertTrue(clientPresenceManager.isPresent(accountUuid, deviceId)); + } + + @Test + public void testLocalDisplacement() { + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + + final AtomicInteger displacementCounter = new AtomicInteger(0); + final DisplacedPresenceListener displacementListener = displacementCounter::incrementAndGet; + + clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener); + + assertEquals(0, displacementCounter.get()); + + clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener); + + assertEquals(1, displacementCounter.get()); + } + + @Test(timeout = 10_000) + public void testRemoteDisplacement() throws InterruptedException { + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + + final AtomicBoolean displaced = new AtomicBoolean(false); + + clientPresenceManager.start(); + + try { + clientPresenceManager.setPresent(accountUuid, deviceId, () -> { + synchronized (displaced) { + displaced.set(true); + displaced.notifyAll(); + } + }); + + getRedisCluster().useWriteCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId), + UUID.randomUUID().toString())); + + synchronized (displaced) { + while (!displaced.get()) { + displaced.wait(); + } + } + } finally { + clientPresenceManager.stop(); + } + } + + @Test(timeout = 10_000) + public void testRemoteDisplacementAfterTopologyChange() throws InterruptedException { + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + + final AtomicBoolean displaced = new AtomicBoolean(false); + + clientPresenceManager.start(); + + try { + clientPresenceManager.setPresent(accountUuid, deviceId, () -> { + synchronized (displaced) { + displaced.set(true); + displaced.notifyAll(); + } + }); + + getRedisCluster().usePubSubConnection(connection -> connection.getResources().eventBus().publish(new ClusterTopologyChangedEvent(List.of(), List.of()))); + + getRedisCluster().useWriteCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId), + UUID.randomUUID().toString())); + + synchronized (displaced) { + while (!displaced.get()) { + displaced.wait(); + } + } + } finally { + clientPresenceManager.stop(); + } + } + + @Test + public void testClearPresence() { + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + + assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId)); + + clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); + assertTrue(clientPresenceManager.clearPresence(accountUuid, deviceId)); + + clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); + getRedisCluster().useWriteCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId), + UUID.randomUUID().toString())); + + assertFalse(clientPresenceManager.clearPresence(accountUuid, deviceId)); + } + + @Test + public void testPruneMissingPeers() { + final String presentPeerId = UUID.randomUUID().toString(); + final String missingPeerId = UUID.randomUUID().toString(); + + getRedisCluster().useWriteCluster(connection -> { + connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId); + connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId); + }); + + for (int i = 0; i < 10; i++) { + addClientPresence(presentPeerId); + addClientPresence(missingPeerId); + } + + getRedisCluster().usePubSubConnection(connection -> connection.sync().masters().commands().subscribe(ClientPresenceManager.getManagerPresenceChannel(presentPeerId))); + + clientPresenceManager.pruneMissingPeers(); + + assertEquals(1, (long)getRedisCluster().withWriteCluster(connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(presentPeerId)))); + assertTrue(getRedisCluster().withReadCluster(connection -> connection.sync().sismember(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId))); + + assertEquals(0, (long)getRedisCluster().withReadCluster(connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(missingPeerId)))); + assertFalse(getRedisCluster().withReadCluster(connection -> connection.sync().sismember(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId))); + } + + private void addClientPresence(final String managerId) { + final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7); + + getRedisCluster().useWriteCluster(connection -> { + connection.sync().set(clientPresenceKey, managerId); + connection.sync().sadd(ClientPresenceManager.getConnectedClientSetKey(managerId), clientPresenceKey); + }); + } + + @Test + public void testClearAllOnStop() { + final int localAccounts = 10; + final UUID[] localUuids = new UUID[localAccounts]; + final long[] localDeviceIds = new long[localAccounts]; + + for (int i = 0; i < localAccounts; i++) { + localUuids[i] = UUID.randomUUID(); + localDeviceIds[i] = i; + + clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP); + } + + final UUID displacedAccountUuid = UUID.randomUUID(); + final long displacedAccountDeviceId = 7; + + clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP); + getRedisCluster().useWriteCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(displacedAccountUuid, displacedAccountDeviceId), + UUID.randomUUID().toString())); + + clientPresenceManager.stop(); + + for (int i = 0; i < localAccounts; i++) { + localUuids[i] = UUID.randomUUID(); + localDeviceIds[i] = i; + + assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i])); + } + + assertTrue(clientPresenceManager.isPresent(displacedAccountUuid, displacedAccountDeviceId)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 1ecd26524..a1382e019 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.WebsocketSender; @@ -69,7 +70,7 @@ public class WebSocketConnectionTest { public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager, mock(ClientPresenceManager.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))