diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 383673ec5..f46eef3e4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -720,6 +720,7 @@ public class WhisperServerService extends Application deviceIds); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java new file mode 100644 index 000000000..e97a92cb2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java @@ -0,0 +1,165 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.lifecycle.Managed; +import io.lettuce.core.pubsub.RedisPubSubAdapter; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +/** + * A disconnection request manager broadcasts and dispatches requests for servers to close authenticated connections + * from specific clients. + * + * @see DisconnectionRequestListener + */ +public class DisconnectionRequestManager extends RedisPubSubAdapter implements Managed { + + private final FaultTolerantRedisClient pubSubClient; + private final Executor listenerEventExecutor; + + // We expect just a couple listeners to get added at startup time and not at all at steady-state. There are several + // reasonable ways to model this, but a copy-on-write list gives us good flexibility with minimal performance cost. + private final List listeners = new CopyOnWriteArrayList<>(); + + @Nullable + private FaultTolerantPubSubConnection pubSubConnection; + + private static final byte[] DISCONNECTION_REQUEST_CHANNEL = "disconnection_requests".getBytes(StandardCharsets.UTF_8); + + private static final Counter DISCONNECTION_REQUESTS_SENT_COUNTER = + Metrics.counter(MetricsUtil.name(DisconnectionRequestManager.class, "requestsSent")); + + private static final Counter DISCONNECTION_REQUESTS_RECEIVED_COUNTER = + Metrics.counter(MetricsUtil.name(DisconnectionRequestManager.class, "requestsReceived")); + + private static final Logger logger = LoggerFactory.getLogger(DisconnectionRequestManager.class); + + public DisconnectionRequestManager(final FaultTolerantRedisClient pubSubClient, + final Executor listenerEventExecutor) { + + this.pubSubClient = pubSubClient; + this.listenerEventExecutor = listenerEventExecutor; + } + + @Override + public synchronized void start() { + this.pubSubConnection = pubSubClient.createBinaryPubSubConnection(); + this.pubSubConnection.usePubSubConnection(connection -> { + connection.addListener(this); + connection.sync().subscribe(DISCONNECTION_REQUEST_CHANNEL); + }); + } + + @Override + public synchronized void stop() { + if (pubSubConnection != null) { + pubSubConnection.usePubSubConnection(connection -> { + connection.removeListener(this); + connection.close(); + }); + } + + pubSubConnection = null; + } + + /** + * Adds a listener for disconnection requests. Listeners will receive all broadcast disconnection requests regardless + * of whether the device in connection is connected to this server. + * + * @param listener the listener to register + */ + public void addListener(final DisconnectionRequestListener listener) { + listeners.add(listener); + } + + /** + * Broadcasts a request to close all connections associated with the given account identifier to all servers. + * + * @param accountIdentifier the account for which to close connections + * + * @return a future that completes when the request has been broadcast + */ + public CompletionStage requestDisconnection(final UUID accountIdentifier) { + return requestDisconnection(accountIdentifier, Collections.emptyList()); + } + + /** + * Broadcasts a request to close connections associated with the given account identifier and device IDs to all + * servers. + * + * @param accountIdentifier the account for which to close connections + * @param deviceIds the device IDs for which to close connections + * + * @return a future that completes when the request has been broadcast + */ + public CompletionStage requestDisconnection(final UUID accountIdentifier, final Collection deviceIds) { + final DisconnectionRequest disconnectionRequest = DisconnectionRequest.newBuilder() + .setAccountIdentifier(UUIDUtil.toByteString(accountIdentifier)) + .addAllDeviceIds(deviceIds.stream().mapToInt(Byte::intValue).boxed().toList()) + .build(); + + return pubSubClient.withBinaryConnection(connection -> + connection.async().publish(DISCONNECTION_REQUEST_CHANNEL, disconnectionRequest.toByteArray())) + .toCompletableFuture() + .thenRun(DISCONNECTION_REQUESTS_SENT_COUNTER::increment); + } + + @Override + public void message(final byte[] channel, final byte[] message) { + final UUID accountIdentifier; + final List deviceIds; + + try { + final DisconnectionRequest disconnectionRequest = DisconnectionRequest.parseFrom(message); + DISCONNECTION_REQUESTS_RECEIVED_COUNTER.increment(); + + accountIdentifier = UUIDUtil.fromByteString(disconnectionRequest.getAccountIdentifier()); + deviceIds = disconnectionRequest.getDeviceIdsCount() > 0 + ? disconnectionRequest.getDeviceIdsList().stream() + .map(deviceIdInt -> { + if (deviceIdInt == null || deviceIdInt < Device.PRIMARY_ID || deviceIdInt > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Invalid device ID: " + deviceIdInt); + } + + return deviceIdInt.byteValue(); + }) + .toList() + : Device.ALL_POSSIBLE_DEVICE_IDS; + } catch (final InvalidProtocolBufferException e) { + logger.error("Could not parse disconnection request protobuf", e); + return; + } catch (final IllegalArgumentException e) { + logger.error("Could not parse part of disconnection request", e); + return; + } + + for (final DisconnectionRequestListener listener : listeners) { + try { + listenerEventExecutor.execute(() -> listener.handleDisconnectionRequest(accountIdentifier, deviceIds)); + } catch (final Exception e) { + logger.warn("Listener failed to handle disconnection request", e); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 215af6e77..d20b9c5f3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -259,6 +259,7 @@ record CommandDependencies( Clock.systemUTC()); environment.lifecycle().manage(apnSender); + environment.lifecycle().manage(disconnectionRequestManager); environment.lifecycle().manage(webSocketConnectionEventManager); environment.lifecycle().manage(new ManagedAwsCrt()); diff --git a/service/src/main/proto/DisconnectionRequests.proto b/service/src/main/proto/DisconnectionRequests.proto new file mode 100644 index 000000000..78ac8c076 --- /dev/null +++ b/service/src/main/proto/DisconnectionRequests.proto @@ -0,0 +1,15 @@ +/** + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +syntax = "proto3"; + +package org.signal.chat.auth; + +option java_package = "org.whispersystems.textsecuregcm.auth"; +option java_multiple_files = true; + +message DisconnectionRequest { + bytes account_identifier = 1; + repeated uint32 device_ids = 2; +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java new file mode 100644 index 000000000..963018524 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java @@ -0,0 +1,99 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisServerExtension; +import org.whispersystems.textsecuregcm.storage.Device; + +@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class DisconnectionRequestManagerTest { + + private DisconnectionRequestManager disconnectionRequestManager; + + @RegisterExtension + static final RedisServerExtension REDIS_EXTENSION = RedisServerExtension.builder().build(); + + private static class DisconnectionRequestTestListener implements DisconnectionRequestListener { + + private final CountDownLatch requestLatch = new CountDownLatch(1); + + private UUID accountIdentifier; + private Collection deviceIds; + + @Override + public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection deviceIds) { + this.accountIdentifier = accountIdentifier; + this.deviceIds = deviceIds; + + requestLatch.countDown(); + } + + public UUID getAccountIdentifier() { + return accountIdentifier; + } + + public Collection getDeviceIds() { + return deviceIds; + } + + public void waitForRequest() throws InterruptedException { + requestLatch.await(); + } + } + + @BeforeEach + void setUp() { + disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(), Runnable::run); + disconnectionRequestManager.start(); + } + + @AfterEach + void tearDown() { + disconnectionRequestManager.stop(); + } + + @Test + void requestDisconnection() throws InterruptedException { + final UUID accountIdentifier = UUID.randomUUID(); + final List deviceIds = List.of(Device.PRIMARY_ID, (byte) (Device.PRIMARY_ID + 1)); + + final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener(); + + disconnectionRequestManager.addListener(listener); + disconnectionRequestManager.requestDisconnection(accountIdentifier, deviceIds).toCompletableFuture().join(); + + listener.waitForRequest(); + + assertEquals(accountIdentifier, listener.getAccountIdentifier()); + assertEquals(deviceIds, listener.getDeviceIds()); + } + + @Test + void requestDisconnectionAllDevices() throws InterruptedException { + final UUID accountIdentifier = UUID.randomUUID(); + + final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener(); + + disconnectionRequestManager.addListener(listener); + disconnectionRequestManager.requestDisconnection(accountIdentifier).toCompletableFuture().join(); + + listener.waitForRequest(); + + assertEquals(accountIdentifier, listener.getAccountIdentifier()); + assertEquals(Device.ALL_POSSIBLE_DEVICE_IDS, listener.getDeviceIds()); + } +}