From 09fd5e881930738099be2d8c380ec8a1a4b27ae9 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 11 Nov 2024 11:19:16 -0500 Subject: [PATCH] Shift authority for disconnection requests to `DisconnectionRequestManager` --- .../textsecuregcm/WhisperServerService.java | 4 ++++ .../grpc/net/GrpcClientConnectionManager.java | 12 +++++++++++- .../push/WebSocketConnectionEventManager.java | 17 +++++++++++++++-- .../WebSocketConnectionEventManagerTest.java | 10 +++------- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index aa35ec47f..7eb2b58b4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -650,6 +650,8 @@ public class WhisperServerService extends Application remoteChannelsByLocalAddress = new ConcurrentHashMap<>(); private final Map> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>(); @@ -215,4 +218,11 @@ public class GrpcClientConnectionManager { })); }); } + + @Override + public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection deviceIds) { + deviceIds.stream() + .map(deviceId -> new AuthenticatedDevice(accountIdentifier, deviceId)) + .forEach(this::closeConnection); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java index c20df6bc3..2c4c83477 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -30,6 +31,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection; @@ -55,7 +57,8 @@ import org.whispersystems.textsecuregcm.util.Util; * @see WebSocketConnectionEventListener * @see org.whispersystems.textsecuregcm.storage.MessagesManager#insert(UUID, byte, MessageProtos.Envelope) */ -public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter implements Managed { +public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter implements Managed, + DisconnectionRequestListener { private final FaultTolerantRedisClusterClient clusterClient; private final Executor listenerEventExecutor; @@ -272,6 +275,14 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter deviceIds) { + deviceIds.stream() + .map(deviceId -> listenersByAccountAndDeviceIdentifier.get(new AccountAndDeviceIdentifier(accountIdentifier, deviceId))) + .filter(Objects::nonNull) + .forEach(listener -> listener.handleConnectionDisplaced(false)); + } + @VisibleForTesting void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) { final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent); @@ -347,7 +358,9 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter listenerEventExecutor.execute(() -> listener.handleConnectionDisplaced(false)); + case DISCONNECT_REQUESTED -> { + // Handle events from `DisconnectionRequestManager` instead + } case MESSAGES_PERSISTED -> listenerEventExecutor.execute(listener::handleMessagesPersisted); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java index 93e665ada..8819e1730 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java @@ -161,9 +161,8 @@ class WebSocketConnectionEventManagerTest { assertFalse(remoteEventManager.isLocallyPresent(accountIdentifier, deviceId)); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException { + @Test + void handleDisconnectionRequest() throws InterruptedException { final UUID accountIdentifier = UUID.randomUUID(); final byte firstDeviceId = Device.PRIMARY_ID; final byte secondDeviceId = firstDeviceId + 1; @@ -198,10 +197,7 @@ class WebSocketConnectionEventManagerTest { assertFalse(firstListenerDisplaced.get()); assertFalse(secondListenerDisplaced.get()); - final WebSocketConnectionEventManager displacingManager = - requestDisconnectionRemotely ? remoteEventManager : localEventManager; - - displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join(); + localEventManager.handleDisconnectionRequest(accountIdentifier, List.of(firstDeviceId)); synchronized (firstListenerDisplaced) { while (!firstListenerDisplaced.get()) {