Send disconnection requests after non-API device unlinks

This commit is contained in:
Jonathan Klabunde Tomer 2025-05-06 13:36:41 -07:00 committed by GitHub
parent 7a91c4d5b7
commit cc7b030a41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 44 additions and 7 deletions

View File

@ -580,6 +580,15 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return Optional.of(aci); return Optional.of(aci);
} }
/**
* Unlink a device from the given account. The device will be immediately disconnected if it is
* connected to any chat frontend, but it is the caller's responsibility to make sure that the
* account's *other* devices are disconnected, either by use of
* {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or
* directly by calling {@link DeviceDisconnectionManager#requestDisconnection}.
*
* @returns the updated Account
*/
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) { public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) { if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device"); throw new IllegalArgumentException("Cannot remove primary device");

View File

@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@ -44,6 +45,7 @@ public class MessagePersister implements Managed {
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager; private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final DisconnectionRequestManager disconnectionRequestManager;
private final Duration persistDelay; private final Duration persistDelay;
@ -82,6 +84,7 @@ public class MessagePersister implements Managed {
final AccountsManager accountsManager, final AccountsManager accountsManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ExperimentEnrollmentManager experimentEnrollmentManager, final ExperimentEnrollmentManager experimentEnrollmentManager,
final DisconnectionRequestManager disconnectionRequestManager,
final Duration persistDelay, final Duration persistDelay,
final int dedicatedProcessWorkerThreadCount) { final int dedicatedProcessWorkerThreadCount) {
@ -90,6 +93,7 @@ public class MessagePersister implements Managed {
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager; this.experimentEnrollmentManager = experimentEnrollmentManager;
this.disconnectionRequestManager = disconnectionRequestManager;
this.persistDelay = persistDelay; this.persistDelay = persistDelay;
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
@ -257,9 +261,10 @@ public class MessagePersister implements Managed {
trimQueue(account, deviceId); trimQueue(account, deviceId);
throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed"); throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed");
} else { } else {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", account.getUuid(), logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId);
deviceId); accountsManager.removeDevice(account, deviceId)
accountsManager.removeDevice(account, deviceId).join(); .thenRun(() -> disconnectionRequestManager.requestDisconnection(accountUuid))
.join();
} }
} finally { } finally {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId); messagesCache.unlockQueueForPersistence(accountUuid, deviceId);

View File

@ -77,6 +77,7 @@ record CommandDependencies(
AccountsManager accountsManager, AccountsManager accountsManager,
ProfilesManager profilesManager, ProfilesManager profilesManager,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
DisconnectionRequestManager disconnectionRequestManager,
MessagesCache messagesCache, MessagesCache messagesCache,
MessagesManager messagesManager, MessagesManager messagesManager,
KeysManager keysManager, KeysManager keysManager,
@ -289,6 +290,7 @@ record CommandDependencies(
accountsManager, accountsManager,
profilesManager, profilesManager,
reportMessageManager, reportMessageManager,
disconnectionRequestManager,
messagesCache, messagesCache,
messagesManager, messagesManager,
keys, keys,

View File

@ -66,6 +66,7 @@ public class MessagePersisterServiceCommand extends ServerCommand<WhisperServerC
deps.accountsManager(), deps.accountsManager(),
deps.dynamicConfigurationManager(), deps.dynamicConfigurationManager(),
new ExperimentEnrollmentManager(deps.dynamicConfigurationManager()), new ExperimentEnrollmentManager(deps.dynamicConfigurationManager()),
deps.disconnectionRequestManager(),
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()), Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
namespace.getInt(WORKER_COUNT)); namespace.getInt(WORKER_COUNT));

View File

@ -110,7 +110,10 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc
final Mono<Long> accountUpdate = dryRun final Mono<Long> accountUpdate = dryRun
? Mono.just((long) expiredDevices.size()) ? Mono.just((long) expiredDevices.size())
: deleteDevices(account, expiredDevices, maxRetries); : deleteDevices(account, expiredDevices, maxRetries)
.flatMap(count ->
Mono.fromCompletionStage(getCommandDependencies().disconnectionRequestManager().requestDisconnection(account.getUuid()))
.then(Mono.just(count)));
return accountUpdate return accountUpdate
.doOnNext(successCounter::increment) .doOnNext(successCounter::increment)

View File

@ -30,6 +30,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@ -98,7 +99,8 @@ class MessagePersisterIntegrationTest {
webSocketConnectionEventManager.start(); webSocketConnectionEventManager.start();
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1); dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), mock(DisconnectionRequestManager.class),
PERSIST_DELAY, 1);
account = mock(Account.class); account = mock(Account.class);

View File

@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -21,6 +20,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly; import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
@ -51,6 +51,7 @@ import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
@ -77,6 +78,7 @@ class MessagePersisterTest {
private MessagePersister messagePersister; private MessagePersister messagePersister;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private MessagesManager messagesManager; private MessagesManager messagesManager;
private DisconnectionRequestManager disconnectionRequestManager;
private Account destinationAccount; private Account destinationAccount;
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
@ -97,6 +99,7 @@ class MessagePersisterTest {
messagesDynamoDb = mock(MessagesDynamoDb.class); messagesDynamoDb = mock(MessagesDynamoDb.class);
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
destinationAccount = mock(Account.class); destinationAccount = mock(Account.class);
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount)); when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount));
@ -119,7 +122,8 @@ class MessagePersisterTest {
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), PERSIST_DELAY, 1); dynamicConfigurationManager, mock(ExperimentEnrollmentManager.class), disconnectionRequestManager,
PERSIST_DELAY, 1);
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
@ -301,6 +305,7 @@ class MessagePersisterTest {
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")); messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
verify(disconnectionRequestManager, exactly()).requestDisconnection(DESTINATION_ACCOUNT_UUID);
} }
@Test @Test
@ -402,6 +407,7 @@ class MessagePersisterTest {
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException()));
assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test")); assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, "test"));
verifyNoMoreInteractions(disconnectionRequestManager);
} }
@SuppressWarnings("SameParameterValue") @SuppressWarnings("SameParameterValue")

View File

@ -74,6 +74,7 @@ class FinishPushNotificationExperimentCommandTest {
null, null,
null, null,
null, null,
null,
pushNotificationExperimentSamples, pushNotificationExperimentSamples,
null, null,
null, null,

View File

@ -59,6 +59,7 @@ class LockAccountsWithoutPniIdentityKeysCommandTest {
null, null,
null, null,
null, null,
null,
null); null);
namespace = new Namespace(Map.of( namespace = new Namespace(Map.of(

View File

@ -50,6 +50,7 @@ class LockAccountsWithoutPqKeysCommandTest {
null, null,
null, null,
null, null,
null,
keysManager, keysManager,
null, null,
null, null,

View File

@ -50,6 +50,7 @@ class NotifyIdleDevicesCommandTest {
null, null,
null, null,
null, null,
null,
messagesManager, messagesManager,
null, null,
null, null,

View File

@ -54,6 +54,7 @@ class RemoveAccountsWithoutPniIdentityKeysCommandTest {
null, null,
null, null,
null, null,
null,
null); null);
namespace = new Namespace(Map.of( namespace = new Namespace(Map.of(

View File

@ -50,6 +50,7 @@ class RemoveAccountsWithoutPqKeysCommandTest {
null, null,
null, null,
null, null,
null,
keysManager, keysManager,
null, null,
null, null,

View File

@ -48,6 +48,7 @@ class RemoveLinkedDevicesWithoutPniKeysCommandTest {
null, null,
null, null,
null, null,
null,
keysManager, keysManager,
null, null,
null, null,

View File

@ -47,6 +47,7 @@ class RemoveLinkedDevicesWithoutPqKeysCommandTest {
null, null,
null, null,
null, null,
null,
keysManager, keysManager,
null, null,
null, null,

View File

@ -63,6 +63,7 @@ class StartPushNotificationExperimentCommandTest {
null, null,
null, null,
null, null,
null,
pushNotificationExperimentSamples, pushNotificationExperimentSamples,
null, null,
null, null,