From 697c380cd1ea7fe25083bb522539bf3666dc51f0 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Tue, 1 Sep 2020 16:55:50 -0400 Subject: [PATCH] Close websocket connections when displaced. --- .../websocket/AuthenticatedConnectListener.java | 5 ++--- .../textsecuregcm/websocket/WebSocketConnection.java | 10 +++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index b0911fd4b..b64f6b0ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -33,7 +33,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration" )); private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration")); private static final Counter openWebsocketCounter = metricRegistry.counter(name(WebSocketConnection.class, "open_websockets")); - private static final Meter explicitDisplacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); private final PushSender pushSender; private final ReceiptSender receiptSender; @@ -74,7 +73,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.inc(); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); - clientPresenceManager.setPresent(account.getUuid(), device.getId(), explicitDisplacementMeter::mark); + + clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection); messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); pubSubManager.publish(address, connectMessage); pubSubManager.subscribe(address, connection); @@ -95,4 +95,3 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { } } } - diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 404328544..7054187a0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -39,7 +40,7 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class WebSocketConnection implements DispatchChannel, MessageAvailabilityListener { +public class WebSocketConnection implements DispatchChannel, MessageAvailabilityListener, DisplacedPresenceListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); @@ -49,6 +50,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); + private static final Meter displacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -230,6 +232,12 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability messagesPersistedMeter.mark(); } + @Override + public void handleDisplacement() { + displacementMeter.mark(); + client.hardDisconnectQuietly(); + } + private static class StoredMessageInfo { private final long id; private final boolean cached;