Shift authority for disconnection requests to `DisconnectionRequestManager`

This commit is contained in:
Jon Chambers 2024-11-11 11:19:16 -05:00 committed by Jon Chambers
parent 81f3ba17c7
commit 09fd5e8819
4 changed files with 33 additions and 10 deletions

View File

@ -650,6 +650,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor =
new MessageDeliveryLoopMonitor(rateLimitersCluster);
disconnectionRequestManager.addListener(webSocketConnectionEventManager);
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, disconnectionRequestManager, webSocketConnectionEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
@ -829,6 +831,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
disconnectionRequestManager.addListener(grpcClientConnectionManager);
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);

View File

@ -9,16 +9,19 @@ import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
@ -30,7 +33,7 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
*/
public class GrpcClientConnectionManager {
public class GrpcClientConnectionManager implements DisconnectionRequestListener {
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
@ -215,4 +218,11 @@ public class GrpcClientConnectionManager {
}));
});
}
@Override
public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
deviceIds.stream()
.map(deviceId -> new AuthenticatedDevice(accountIdentifier, deviceId))
.forEach(this::closeConnection);
}
}

View File

@ -21,6 +21,7 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
@ -30,6 +31,7 @@ import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
@ -55,7 +57,8 @@ import org.whispersystems.textsecuregcm.util.Util;
* @see WebSocketConnectionEventListener
* @see org.whispersystems.textsecuregcm.storage.MessagesManager#insert(UUID, byte, MessageProtos.Envelope)
*/
public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<byte[], byte[]> implements Managed {
public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<byte[], byte[]> implements Managed,
DisconnectionRequestListener {
private final FaultTolerantRedisClusterClient clusterClient;
private final Executor listenerEventExecutor;
@ -272,6 +275,14 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<b
.toArray(CompletableFuture[]::new));
}
@Override
public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
deviceIds.stream()
.map(deviceId -> listenersByAccountAndDeviceIdentifier.get(new AccountAndDeviceIdentifier(accountIdentifier, deviceId)))
.filter(Objects::nonNull)
.forEach(listener -> listener.handleConnectionDisplaced(false));
}
@VisibleForTesting
void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent);
@ -347,7 +358,9 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter<b
}
}
case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() -> listener.handleConnectionDisplaced(false));
case DISCONNECT_REQUESTED -> {
// Handle events from `DisconnectionRequestManager` instead
}
case MESSAGES_PERSISTED -> listenerEventExecutor.execute(listener::handleMessagesPersisted);

View File

@ -161,9 +161,8 @@ class WebSocketConnectionEventManagerTest {
assertFalse(remoteEventManager.isLocallyPresent(accountIdentifier, deviceId));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException {
@Test
void handleDisconnectionRequest() throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID;
final byte secondDeviceId = firstDeviceId + 1;
@ -198,10 +197,7 @@ class WebSocketConnectionEventManagerTest {
assertFalse(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
final WebSocketConnectionEventManager displacingManager =
requestDisconnectionRemotely ? remoteEventManager : localEventManager;
displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join();
localEventManager.handleDisconnectionRequest(accountIdentifier, List.of(firstDeviceId));
synchronized (firstListenerDisplaced) {
while (!firstListenerDisplaced.get()) {