Reframe "connection ID" as "server ID" to avoid double-removing clients

This commit is contained in:
Jon Chambers 2024-11-05 18:12:22 -05:00 committed by Jon Chambers
parent d8f53954d0
commit 3e36a49142
4 changed files with 47 additions and 79 deletions

View File

@ -45,13 +45,22 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
private final FaultTolerantRedisClusterClient clusterClient;
private final Executor listenerEventExecutor;
// Random identifier for this presence manager instance; compared against the server ID carried by incoming
// "client connected" events to tell events published by this instance apart from those of other instances.
private final UUID serverId = UUID.randomUUID();
// Pre-serialized "client connected" event carrying this instance's server ID; published to a client's presence key
// when that client connects. This is an instance field (not static) because its contents embed `serverId`.
private final byte[] CLIENT_CONNECTED_EVENT_BYTES = ClientEvent.newBuilder()
.setClientConnected(ClientConnectedEvent.newBuilder()
.setServerId(UUIDUtil.toByteString(serverId))
.build())
.build()
.toByteArray();
private final ExperimentEnrollmentManager experimentEnrollmentManager;
static final String EXPERIMENT_NAME = "pubSubPresenceManager";
@Nullable
private FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubConnection;
private final Map<AccountAndDeviceIdentifier, ConnectionIdAndListener> listenersByAccountAndDeviceIdentifier;
private final Map<AccountAndDeviceIdentifier, ClientEventListener> listenersByAccountAndDeviceIdentifier;
private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder()
.setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
@ -80,9 +89,6 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
/**
 * Identifies a single device by its account identifier and its device ID within that account; used as the key of
 * the listener map.
 */
private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) {
}
/**
 * Pairs a registered client event listener with the identifier of the connection it was registered for.
 */
private record ConnectionIdAndListener(UUID connectionIdentifier, ClientEventListener listener) {
}
public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
@ -154,15 +160,15 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(key, existingIdAndListener) -> {
(key, existingListener) -> {
subscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().ssubscribe(clientPresenceKey)));
if (existingIdAndListener != null) {
displacedListener.set(existingIdAndListener.listener());
if (existingListener != null) {
displacedListener.set(existingListener);
}
return new ConnectionIdAndListener(connectionId, listener);
return listener;
});
if (displacedListener.get() != null) {
@ -171,7 +177,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
return subscribeFuture.get()
.thenCompose(ignored -> clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, buildClientConnectedMessage(connectionId))))
.spublish(clientPresenceKey, CLIENT_CONNECTED_EVENT_BYTES)))
.handle((ignored, throwable) -> {
if (throwable != null) {
PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER.increment();
@ -182,17 +188,15 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
}
/**
* Removes the "presence" for the given device. The presence is removed if and only if the given connection ID matches
* the connection ID for the currently-registered presence. Callers should call this method when they have closed or
* intend to close the client's underlying network connection.
* Removes the "presence" for the given device. Callers should call this method when they have been notified that
* the client's underlying network connection has been closed.
*
* @param accountIdentifier the identifier of the account for the disconnected device
* @param deviceId the ID of the disconnected device within the given account
* @param connectionId the ID of the connection that has been closed (or will be closed)
*
* @return a future that completes when the presence has been removed
*/
public CompletionStage<Void> handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final UUID connectionId) {
public CompletionStage<Void> handleClientDisconnected(final UUID accountIdentifier, final byte deviceId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
@ -214,35 +218,19 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(ignored, existingIdAndListener) -> {
final ConnectionIdAndListener remainingIdAndListener;
(ignored, existingListener) -> {
unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId)))
.thenRun(Util.NOOP));
if (existingIdAndListener == null) {
remainingIdAndListener = null;
} else if (existingIdAndListener.connectionIdentifier().equals(connectionId)) {
remainingIdAndListener = null;
} else {
remainingIdAndListener = existingIdAndListener;
}
if (remainingIdAndListener == null) {
// Only unsubscribe if there's no listener remaining
unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId)))
.thenRun(Util.NOOP));
} else {
unsubscribeFuture.set(CompletableFuture.completedFuture(null));
}
return remainingIdAndListener;
return null;
});
return unsubscribeFuture.get()
.whenComplete((ignored, throwable) -> {
if (throwable != null) {
UNSUBSCRIBE_ERROR_COUNTER.increment();
}
});
return unsubscribeFuture.get().whenComplete((ignored, throwable) -> {
if (throwable != null) {
UNSUBSCRIBE_ERROR_COUNTER.increment();
}
});
}
/**
@ -355,24 +343,22 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
final AccountAndDeviceIdentifier accountAndDeviceIdentifier = parseClientPresenceKey(shardChannel);
@Nullable final ConnectionIdAndListener connectionIdAndListener =
@Nullable final ClientEventListener listener =
listenersByAccountAndDeviceIdentifier.get(accountAndDeviceIdentifier);
if (connectionIdAndListener != null) {
if (listener != null) {
switch (clientEvent.getEventCase()) {
case NEW_MESSAGE_AVAILABLE -> connectionIdAndListener.listener().handleNewMessageAvailable();
case NEW_MESSAGE_AVAILABLE -> listener.handleNewMessageAvailable();
case CLIENT_CONNECTED -> {
final UUID connectionId = UUIDUtil.fromByteString(clientEvent.getClientConnected().getConnectionId());
if (!connectionIdAndListener.connectionIdentifier().equals(connectionId)) {
listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(true));
// Only act on new connections to other presence manager instances; we'll learn about displacements in THIS
// instance when we update the listener map in `handleClientConnected`
if (!this.serverId.equals(UUIDUtil.fromByteString(clientEvent.getClientConnected().getServerId()))) {
listenerEventExecutor.execute(() -> listener.handleConnectionDisplaced(true));
}
}
case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(false));
case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() -> listener.handleConnectionDisplaced(false));
default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass());
}
@ -381,15 +367,6 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
}
}
/**
 * Serializes a "client connected" event that carries the given connection ID.
 *
 * @param connectionId the connection ID to embed in the event
 *
 * @return the serialized {@code ClientEvent} bytes
 */
private static byte[] buildClientConnectedMessage(final UUID connectionId) {
  // Build the inner payload first, then wrap it in the enclosing event envelope
  final ClientConnectedEvent clientConnectedEvent = ClientConnectedEvent.newBuilder()
      .setConnectionId(UUIDUtil.toByteString(connectionId))
      .build();

  final ClientEvent clientEvent = ClientEvent.newBuilder()
      .setClientConnected(clientConnectedEvent)
      .build();

  return clientEvent.toByteArray();
}
@VisibleForTesting
static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) {
return ("client_presence::{" + accountIdentifier + ":" + deviceId + "}").getBytes(StandardCharsets.UTF_8);

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
@ -58,8 +57,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
private transient UUID connectionId;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager,
MessageMetrics messageMetrics,
@ -128,11 +125,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// It's preferable to start sending push notifications as soon as possible.
RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection));
if (connectionId != null) {
pubSubClientEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
connectionId);
}
pubSubClientEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId());
// Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection
// will not be notified and will not change its state, but that's okay because it has already closed and
@ -160,8 +154,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
pubSubClientEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), null)
.thenAccept(connectionId -> this.connectionId = connectionId);
pubSubClientEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), null);
renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() ->
clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())),

View File

@ -27,7 +27,7 @@ message NewMessageAvailableEvent {
* Indicates that a client has connected to the presence system.
*/
message ClientConnectedEvent {
bytes connection_id = 1;
bytes server_id = 1;
}
/**

View File

@ -175,13 +175,12 @@ class PubSubClientEventManagerTest {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture().join();
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture().join();
assertTrue(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId).toCompletableFuture().join();
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId).toCompletableFuture().join();
assertFalse(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
}
@ -194,15 +193,14 @@ class PubSubClientEventManagerTest {
assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture()
.join();
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture()
.join();
assertTrue(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId)
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId)
.toCompletableFuture()
.join();