Introduce `DisconnectionRequestManager`
This commit is contained in:
parent
1323b42169
commit
7e861f388f
|
@ -720,6 +720,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
environment.lifecycle().manage(apnSender);
|
environment.lifecycle().manage(apnSender);
|
||||||
environment.lifecycle().manage(pushNotificationScheduler);
|
environment.lifecycle().manage(pushNotificationScheduler);
|
||||||
environment.lifecycle().manage(provisioningManager);
|
environment.lifecycle().manage(provisioningManager);
|
||||||
|
environment.lifecycle().manage(disconnectionRequestManager);
|
||||||
environment.lifecycle().manage(webSocketConnectionEventManager);
|
environment.lifecycle().manage(webSocketConnectionEventManager);
|
||||||
environment.lifecycle().manage(currencyManager);
|
environment.lifecycle().manage(currencyManager);
|
||||||
environment.lifecycle().manage(registrationServiceClient);
|
environment.lifecycle().manage(registrationServiceClient);
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.auth;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A disconnection request listener receives and handles requests to close authenticated client network connections.
|
||||||
|
*/
|
||||||
|
public interface DisconnectionRequestListener {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a request to close authenticated network connections for one or more authenticated devices. Requests are
|
||||||
|
* dispatched on dedicated threads, and implementations may safely block.
|
||||||
|
*
|
||||||
|
* @param accountIdentifier the account identifier for which to close authenticated connections
|
||||||
|
* @param deviceIds the device IDs within the identified account for which to close authenticated connections
|
||||||
|
*/
|
||||||
|
void handleDisconnectionRequest(UUID accountIdentifier, Collection<Byte> deviceIds);
|
||||||
|
}
|
|
@ -0,0 +1,165 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.auth;
|
||||||
|
|
||||||
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.lettuce.core.pubsub.RedisPubSubAdapter;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletionStage;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import io.micrometer.core.instrument.Counter;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A disconnection request manager broadcasts and dispatches requests for servers to close authenticated connections
|
||||||
|
* from specific clients.
|
||||||
|
*
|
||||||
|
* @see DisconnectionRequestListener
|
||||||
|
*/
|
||||||
|
public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
|
||||||
|
|
||||||
|
private final FaultTolerantRedisClient pubSubClient;
|
||||||
|
private final Executor listenerEventExecutor;
|
||||||
|
|
||||||
|
// We expect just a couple listeners to get added at startup time and not at all at steady-state. There are several
|
||||||
|
// reasonable ways to model this, but a copy-on-write list gives us good flexibility with minimal performance cost.
|
||||||
|
private final List<DisconnectionRequestListener> listeners = new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private FaultTolerantPubSubConnection<byte[], byte[]> pubSubConnection;
|
||||||
|
|
||||||
|
private static final byte[] DISCONNECTION_REQUEST_CHANNEL = "disconnection_requests".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
private static final Counter DISCONNECTION_REQUESTS_SENT_COUNTER =
|
||||||
|
Metrics.counter(MetricsUtil.name(DisconnectionRequestManager.class, "requestsSent"));
|
||||||
|
|
||||||
|
private static final Counter DISCONNECTION_REQUESTS_RECEIVED_COUNTER =
|
||||||
|
Metrics.counter(MetricsUtil.name(DisconnectionRequestManager.class, "requestsReceived"));
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(DisconnectionRequestManager.class);
|
||||||
|
|
||||||
|
public DisconnectionRequestManager(final FaultTolerantRedisClient pubSubClient,
|
||||||
|
final Executor listenerEventExecutor) {
|
||||||
|
|
||||||
|
this.pubSubClient = pubSubClient;
|
||||||
|
this.listenerEventExecutor = listenerEventExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void start() {
|
||||||
|
this.pubSubConnection = pubSubClient.createBinaryPubSubConnection();
|
||||||
|
this.pubSubConnection.usePubSubConnection(connection -> {
|
||||||
|
connection.addListener(this);
|
||||||
|
connection.sync().subscribe(DISCONNECTION_REQUEST_CHANNEL);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized void stop() {
|
||||||
|
if (pubSubConnection != null) {
|
||||||
|
pubSubConnection.usePubSubConnection(connection -> {
|
||||||
|
connection.removeListener(this);
|
||||||
|
connection.close();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pubSubConnection = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a listener for disconnection requests. Listeners will receive all broadcast disconnection requests regardless
|
||||||
|
* of whether the device in connection is connected to this server.
|
||||||
|
*
|
||||||
|
* @param listener the listener to register
|
||||||
|
*/
|
||||||
|
public void addListener(final DisconnectionRequestListener listener) {
|
||||||
|
listeners.add(listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Broadcasts a request to close all connections associated with the given account identifier to all servers.
|
||||||
|
*
|
||||||
|
* @param accountIdentifier the account for which to close connections
|
||||||
|
*
|
||||||
|
* @return a future that completes when the request has been broadcast
|
||||||
|
*/
|
||||||
|
public CompletionStage<Void> requestDisconnection(final UUID accountIdentifier) {
|
||||||
|
return requestDisconnection(accountIdentifier, Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Broadcasts a request to close connections associated with the given account identifier and device IDs to all
|
||||||
|
* servers.
|
||||||
|
*
|
||||||
|
* @param accountIdentifier the account for which to close connections
|
||||||
|
* @param deviceIds the device IDs for which to close connections
|
||||||
|
*
|
||||||
|
* @return a future that completes when the request has been broadcast
|
||||||
|
*/
|
||||||
|
public CompletionStage<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
|
||||||
|
final DisconnectionRequest disconnectionRequest = DisconnectionRequest.newBuilder()
|
||||||
|
.setAccountIdentifier(UUIDUtil.toByteString(accountIdentifier))
|
||||||
|
.addAllDeviceIds(deviceIds.stream().mapToInt(Byte::intValue).boxed().toList())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return pubSubClient.withBinaryConnection(connection ->
|
||||||
|
connection.async().publish(DISCONNECTION_REQUEST_CHANNEL, disconnectionRequest.toByteArray()))
|
||||||
|
.toCompletableFuture()
|
||||||
|
.thenRun(DISCONNECTION_REQUESTS_SENT_COUNTER::increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void message(final byte[] channel, final byte[] message) {
|
||||||
|
final UUID accountIdentifier;
|
||||||
|
final List<Byte> deviceIds;
|
||||||
|
|
||||||
|
try {
|
||||||
|
final DisconnectionRequest disconnectionRequest = DisconnectionRequest.parseFrom(message);
|
||||||
|
DISCONNECTION_REQUESTS_RECEIVED_COUNTER.increment();
|
||||||
|
|
||||||
|
accountIdentifier = UUIDUtil.fromByteString(disconnectionRequest.getAccountIdentifier());
|
||||||
|
deviceIds = disconnectionRequest.getDeviceIdsCount() > 0
|
||||||
|
? disconnectionRequest.getDeviceIdsList().stream()
|
||||||
|
.map(deviceIdInt -> {
|
||||||
|
if (deviceIdInt == null || deviceIdInt < Device.PRIMARY_ID || deviceIdInt > Byte.MAX_VALUE) {
|
||||||
|
throw new IllegalArgumentException("Invalid device ID: " + deviceIdInt);
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceIdInt.byteValue();
|
||||||
|
})
|
||||||
|
.toList()
|
||||||
|
: Device.ALL_POSSIBLE_DEVICE_IDS;
|
||||||
|
} catch (final InvalidProtocolBufferException e) {
|
||||||
|
logger.error("Could not parse disconnection request protobuf", e);
|
||||||
|
return;
|
||||||
|
} catch (final IllegalArgumentException e) {
|
||||||
|
logger.error("Could not parse part of disconnection request", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (final DisconnectionRequestListener listener : listeners) {
|
||||||
|
try {
|
||||||
|
listenerEventExecutor.execute(() -> listener.handleDisconnectionRequest(accountIdentifier, deviceIds));
|
||||||
|
} catch (final Exception e) {
|
||||||
|
logger.warn("Listener failed to handle disconnection request", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -259,6 +259,7 @@ record CommandDependencies(
|
||||||
Clock.systemUTC());
|
Clock.systemUTC());
|
||||||
|
|
||||||
environment.lifecycle().manage(apnSender);
|
environment.lifecycle().manage(apnSender);
|
||||||
|
environment.lifecycle().manage(disconnectionRequestManager);
|
||||||
environment.lifecycle().manage(webSocketConnectionEventManager);
|
environment.lifecycle().manage(webSocketConnectionEventManager);
|
||||||
environment.lifecycle().manage(new ManagedAwsCrt());
|
environment.lifecycle().manage(new ManagedAwsCrt());
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package org.signal.chat.auth;
|
||||||
|
|
||||||
|
option java_package = "org.whispersystems.textsecuregcm.auth";
|
||||||
|
option java_multiple_files = true;
|
||||||
|
|
||||||
|
message DisconnectionRequest {
|
||||||
|
bytes account_identifier = 1;
|
||||||
|
repeated uint32 device_ids = 2;
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.auth;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.Timeout;
|
||||||
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
|
||||||
|
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
|
||||||
|
class DisconnectionRequestManagerTest {
|
||||||
|
|
||||||
|
private DisconnectionRequestManager disconnectionRequestManager;
|
||||||
|
|
||||||
|
@RegisterExtension
|
||||||
|
static final RedisServerExtension REDIS_EXTENSION = RedisServerExtension.builder().build();
|
||||||
|
|
||||||
|
private static class DisconnectionRequestTestListener implements DisconnectionRequestListener {
|
||||||
|
|
||||||
|
private final CountDownLatch requestLatch = new CountDownLatch(1);
|
||||||
|
|
||||||
|
private UUID accountIdentifier;
|
||||||
|
private Collection<Byte> deviceIds;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
|
||||||
|
this.accountIdentifier = accountIdentifier;
|
||||||
|
this.deviceIds = deviceIds;
|
||||||
|
|
||||||
|
requestLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
public UUID getAccountIdentifier() {
|
||||||
|
return accountIdentifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Collection<Byte> getDeviceIds() {
|
||||||
|
return deviceIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void waitForRequest() throws InterruptedException {
|
||||||
|
requestLatch.await();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(), Runnable::run);
|
||||||
|
disconnectionRequestManager.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() {
|
||||||
|
disconnectionRequestManager.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void requestDisconnection() throws InterruptedException {
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final List<Byte> deviceIds = List.of(Device.PRIMARY_ID, (byte) (Device.PRIMARY_ID + 1));
|
||||||
|
|
||||||
|
final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener();
|
||||||
|
|
||||||
|
disconnectionRequestManager.addListener(listener);
|
||||||
|
disconnectionRequestManager.requestDisconnection(accountIdentifier, deviceIds).toCompletableFuture().join();
|
||||||
|
|
||||||
|
listener.waitForRequest();
|
||||||
|
|
||||||
|
assertEquals(accountIdentifier, listener.getAccountIdentifier());
|
||||||
|
assertEquals(deviceIds, listener.getDeviceIds());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void requestDisconnectionAllDevices() throws InterruptedException {
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
|
||||||
|
final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener();
|
||||||
|
|
||||||
|
disconnectionRequestManager.addListener(listener);
|
||||||
|
disconnectionRequestManager.requestDisconnection(accountIdentifier).toCompletableFuture().join();
|
||||||
|
|
||||||
|
listener.waitForRequest();
|
||||||
|
|
||||||
|
assertEquals(accountIdentifier, listener.getAccountIdentifier());
|
||||||
|
assertEquals(Device.ALL_POSSIBLE_DEVICE_IDS, listener.getDeviceIds());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue