Introduce and evaluate a client presence manager based on sharded pub/sub

This commit is contained in:
Jon Chambers 2024-11-05 15:51:29 -05:00 committed by GitHub
parent 60cdcf5f0c
commit 8c984cbf42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 1339 additions and 56 deletions

View File

@ -56,7 +56,7 @@
<jsr305.version>3.0.2</jsr305.version> <jsr305.version>3.0.2</jsr305.version>
<kotlin.version>1.9.24</kotlin.version> <kotlin.version>1.9.24</kotlin.version>
<kotlinx-serialization.version>1.5.1</kotlinx-serialization.version> <kotlinx-serialization.version>1.5.1</kotlinx-serialization.version>
<lettuce.version>6.3.2.RELEASE</lettuce.version> <lettuce.version>6.4.1.RELEASE</lettuce.version>
<libphonenumber.version>8.13.40</libphonenumber.version> <libphonenumber.version>8.13.40</libphonenumber.version>
<logstash.logback.version>7.3</logstash.logback.version> <logstash.logback.version>7.3</logstash.logback.version>
<log4j-bom.version>2.23.1</log4j-bom.version> <log4j-bom.version>2.23.1</log4j-bom.version>

View File

@ -196,6 +196,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -569,6 +570,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d")); .virtualExecutorService(name(getClass(), "googlePlayBilling-%d"));
ExecutorService appleAppStoreExecutor = environment.lifecycle() ExecutorService appleAppStoreExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "appleAppStore-%d")); .virtualExecutorService(name(getClass(), "appleAppStore-%d"));
ExecutorService clientEventExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "clientEvent-%d"));
ScheduledExecutorService appleAppStoreRetryExecutor = environment.lifecycle() ScheduledExecutorService appleAppStoreRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "appleAppStoreRetry-%d")).threads(1).build(); .scheduledExecutorService(name(getClass(), "appleAppStoreRetry-%d")).threads(1).build();
@ -619,6 +622,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration()); storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor, ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor,
keyspaceNotificationDispatchExecutor); keyspaceNotificationDispatchExecutor);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager); messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
@ -637,7 +641,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager, pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, secureStorageClient, secureValueRecovery2Client,
clientPresenceManager, clientPresenceManager, pubSubClientEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs); RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
@ -667,7 +671,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MessageDeliveryLoopMonitor(rateLimitersCluster); new MessageDeliveryLoopMonitor(rateLimitersCluster);
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator, accountsManager, clientPresenceManager, pubSubClientEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters); registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
final ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener( final ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(
@ -677,7 +681,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager); final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
final MessageSender messageSender = final MessageSender messageSender =
new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager); new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor); final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager, final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager,
config.getTurnConfiguration().secret().value()); config.getTurnConfiguration().secret().value());
@ -745,6 +749,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(provisioningManager); environment.lifecycle().manage(provisioningManager);
environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager); environment.lifecycle().manage(clientPresenceManager);
environment.lifecycle().manage(pubSubClientEventManager);
environment.lifecycle().manage(currencyManager); environment.lifecycle().manage(currencyManager);
environment.lifecycle().manage(registrationServiceClient); environment.lifecycle().manage(registrationServiceClient);
environment.lifecycle().manage(keyTransparencyServiceClient); environment.lifecycle().manage(keyTransparencyServiceClient);
@ -996,7 +1001,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(MultiRecipientMessageProvider.class); environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new AuthDynamicFeature(accountAuthFilter)); environment.jersey().register(new AuthDynamicFeature(accountAuthFilter));
environment.jersey().register(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)); environment.jersey().register(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class));
environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
pubSubClientEventManager));
environment.jersey().register(new TimestampResponseFilter()); environment.jersey().register(new TimestampResponseFilter());
/// ///
@ -1006,10 +1012,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager))); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager, new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, pushNotificationScheduler, clientPresenceManager, pubSubClientEventManager, websocketScheduledExecutor,
clientReleaseManager, messageDeliveryLoopMonitor)); messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor));
webSocketEnvironment.jersey() webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); .register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
pubSubClientEventManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters)); webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class); webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
@ -1151,7 +1158,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment, WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000)); webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
pubSubClientEventManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager)); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));

View File

@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -55,6 +56,7 @@ public class RegistrationLockVerificationManager {
private final AccountsManager accounts; private final AccountsManager accounts;
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final ExternalServiceCredentialsGenerator svr2CredentialGenerator; private final ExternalServiceCredentialsGenerator svr2CredentialGenerator;
private final ExternalServiceCredentialsGenerator svr3CredentialGenerator; private final ExternalServiceCredentialsGenerator svr3CredentialGenerator;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
@ -62,7 +64,9 @@ public class RegistrationLockVerificationManager {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
public RegistrationLockVerificationManager( public RegistrationLockVerificationManager(
final AccountsManager accounts, final ClientPresenceManager clientPresenceManager, final AccountsManager accounts,
final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
final ExternalServiceCredentialsGenerator svr2CredentialGenerator, final ExternalServiceCredentialsGenerator svr2CredentialGenerator,
final ExternalServiceCredentialsGenerator svr3CredentialGenerator, final ExternalServiceCredentialsGenerator svr3CredentialGenerator,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
@ -70,6 +74,7 @@ public class RegistrationLockVerificationManager {
final RateLimiters rateLimiters) { final RateLimiters rateLimiters) {
this.accounts = accounts; this.accounts = accounts;
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.svr2CredentialGenerator = svr2CredentialGenerator; this.svr2CredentialGenerator = svr2CredentialGenerator;
this.svr3CredentialGenerator = svr3CredentialGenerator; this.svr3CredentialGenerator = svr3CredentialGenerator;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
@ -161,6 +166,7 @@ public class RegistrationLockVerificationManager {
final List<Byte> deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); final List<Byte> deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList();
clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds); clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds);
pubSubClientEventManager.requestDisconnection(updatedAccount.getUuid(), deviceIds);
try { try {
// Send a push notification that prompts the client to attempt login and fail due to locked credentials // Send a push notification that prompts the client to attempt login and fail due to locked credentials

View File

@ -10,6 +10,7 @@ import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
/** /**
@ -20,9 +21,11 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener; private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener;
public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager, public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager,
final ClientPresenceManager clientPresenceManager) { final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager) {
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
pubSubClientEventManager,
new LinkedDeviceRefreshRequirementProvider(accountsManager), new LinkedDeviceRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider(accountsManager)); new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
} }

View File

@ -10,6 +10,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import javax.ws.rs.container.ResourceInfo; import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Context; import javax.ws.rs.core.Context;
@ -19,10 +20,12 @@ import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
public class WebsocketRefreshRequestEventListener implements RequestEventListener { public class WebsocketRefreshRequestEventListener implements RequestEventListener {
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final WebsocketRefreshRequirementProvider[] providers; private final WebsocketRefreshRequirementProvider[] providers;
private static final Counter DISPLACED_ACCOUNTS = Metrics.counter( private static final Counter DISPLACED_ACCOUNTS = Metrics.counter(
@ -35,9 +38,11 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
public WebsocketRefreshRequestEventListener( public WebsocketRefreshRequestEventListener(
final ClientPresenceManager clientPresenceManager, final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
final WebsocketRefreshRequirementProvider... providers) { final WebsocketRefreshRequirementProvider... providers) {
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.providers = providers; this.providers = providers;
} }
@ -60,6 +65,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
try { try {
displacedDevices.incrementAndGet(); displacedDevices.incrementAndGet();
clientPresenceManager.disconnectPresence(pair.first(), pair.second()); clientPresenceManager.disconnectPresence(pair.first(), pair.second());
pubSubClientEventManager.requestDisconnection(pair.first(), List.of(pair.second()));
} catch (final Exception e) { } catch (final Exception e) {
logger.error("Could not displace device presence", e); logger.error("Could not displace device presence", e);
} }

View File

@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
/**
* A client event listener handles events related to a client's message-retrieval presence. Handler methods are run on
* dedicated threads and may safely perform blocking operations.
*/
public interface ClientEventListener {
/**
* Indicates that a new message is available in the connected client's message queue.
*/
void handleNewMessageAvailable();
/**
* Indicates that the client's presence has been displaced and the listener should close the client's underlying
* network connection.
*
* @param connectedElsewhere if {@code true}, indicates that the client's presence has been displaced by another
* connection from the same client
*/
void handleConnectionDisplaced(boolean connectedElsewhere);
}

View File

@ -14,6 +14,7 @@ import io.lettuce.core.RedisFuture;
import io.lettuce.core.ScriptOutputType; import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
@ -277,7 +278,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
.subscribe(getKeyspaceNotificationChannel(presenceKey))); .subscribe(getKeyspaceNotificationChannel(presenceKey)));
} }
private void resubscribeAll() { private void resubscribeAll(final ClusterTopologyChangedEvent event) {
for (final String presenceKey : displacementListenersByPresenceKey.keySet()) { for (final String presenceKey : displacementListenersByPresenceKey.keySet()) {
subscribeForRemotePresenceChanges(presenceKey); subscribeForRemotePresenceChanges(presenceKey);
} }

View File

@ -8,9 +8,11 @@ import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import java.util.Objects;
/** /**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@ -28,6 +30,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
public class MessageSender { public class MessageSender {
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
@ -35,15 +38,18 @@ public class MessageSender {
private static final String CHANNEL_TAG_NAME = "channel"; private static final String CHANNEL_TAG_NAME = "channel";
private static final String EPHEMERAL_TAG_NAME = "ephemeral"; private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String CLIENT_ONLINE_TAG_NAME = "clientOnline"; private static final String CLIENT_ONLINE_TAG_NAME = "clientOnline";
private static final String PUB_SUB_CLIENT_ONLINE_TAG_NAME = "pubSubClientOnline";
private static final String URGENT_TAG_NAME = "urgent"; private static final String URGENT_TAG_NAME = "urgent";
private static final String STORY_TAG_NAME = "story"; private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender"; private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
public MessageSender(final ClientPresenceManager clientPresenceManager, public MessageSender(final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
final MessagesManager messagesManager, final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager) { final PushNotificationManager pushNotificationManager) {
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
} }
@ -88,13 +94,15 @@ public class MessageSender {
} }
} }
Metrics.counter(SEND_COUNTER_NAME, pubSubClientEventManager.handleNewMessageAvailable(account.getIdentifier(IdentityType.ACI), device.getId())
CHANNEL_TAG_NAME, channel, .whenComplete((present, throwable) -> Metrics.counter(SEND_COUNTER_NAME,
EPHEMERAL_TAG_NAME, String.valueOf(online), CHANNEL_TAG_NAME, channel,
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent), EPHEMERAL_TAG_NAME, String.valueOf(online),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
STORY_TAG_NAME, String.valueOf(message.getStory()), PUB_SUB_CLIENT_ONLINE_TAG_NAME, String.valueOf(Objects.requireNonNullElse(present, false)),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
.increment(); STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment());
} }
} }

View File

@ -0,0 +1,407 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
/**
* The pub/sub-based client presence manager uses the Redis 7 sharded pub/sub system to notify connected clients that
* new messages are available for retrieval and report to senders whether a client was present to receive a message when
* sent. This system makes a best effort to ensure that a given client has only a single open connection across the
* fleet of servers, but cannot guarantee at-most-one behavior.
*/
public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[], byte[]> implements Managed {
private final FaultTolerantRedisClusterClient clusterClient;
private final Executor listenerEventExecutor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
static final String EXPERIMENT_NAME = "pubSubPresenceManager";
@Nullable
private FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubConnection;
private final Map<AccountAndDeviceIdentifier, ConnectionIdAndListener> listenersByAccountAndDeviceIdentifier;
private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder()
.setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
.build()
.toByteArray();
private static final byte[] DISCONNECT_REQUESTED_EVENT_BYTES = ClientEvent.newBuilder()
.setDisconnectRequested(DisconnectRequested.getDefaultInstance())
.build()
.toByteArray();
private static final Counter PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "publishClientConnectionEventError"));
private static final Counter UNSUBSCRIBE_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "unsubscribeError"));
private static final Counter MESSAGE_WITHOUT_LISTENER_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "messageWithoutListener"));
private static final String LISTENER_GAUGE_NAME =
MetricsUtil.name(PubSubClientEventManager.class, "listeners");
private static final Logger logger = LoggerFactory.getLogger(PubSubClientEventManager.class);
private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) {
}
private record ConnectionIdAndListener(UUID connectionIdentifier, ClientEventListener listener) {
}
public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.clusterClient = clusterClient;
this.listenerEventExecutor = listenerEventExecutor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.listenersByAccountAndDeviceIdentifier =
Metrics.gaugeMapSize(LISTENER_GAUGE_NAME, Tags.empty(), new ConcurrentHashMap<>());
}
@Override
public synchronized void start() {
this.pubSubConnection = clusterClient.createBinaryPubSubConnection();
this.pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
pubSubConnection.subscribeToClusterTopologyChangedEvents(this::resubscribe);
}
@Override
public synchronized void stop() {
if (pubSubConnection != null) {
pubSubConnection.usePubSubConnection(connection -> {
connection.removeListener(this);
connection.close();
});
}
pubSubConnection = null;
}
/**
* Marks the given device as "present" and registers a listener for new messages and conflicting connections. If the
* given device already has a presence registered with this presence manager instance, that presence is displaced
* immediately and the listener's {@link ClientEventListener#handleConnectionDisplaced(boolean)} method is called.
*
* @param accountIdentifier the account identifier for the newly-connected device
* @param deviceId the ID of the newly-connected device within the given account
* @param listener the listener to notify when new messages or conflicting connections arrive for the newly-conencted
* device
*
* @return a future that yields a connection identifier when the new device's presence has been registered; the future
* may fail if a pub/sub subscription could not be established, in which case callers should close the client's
* connection to the server
*/
public CompletionStage<UUID> handleClientConnected(final UUID accountIdentifier, final byte deviceId, final ClientEventListener listener) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(UUID.randomUUID());
}
final UUID connectionId = UUID.randomUUID();
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
final AtomicReference<ClientEventListener> displacedListener = new AtomicReference<>();
final AtomicReference<CompletionStage<Void>> subscribeFuture = new AtomicReference<>();
// Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
// particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
//
// > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
// > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
// > computation is in progress, so the computation should be short and simple.
//
// This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(key, existingIdAndListener) -> {
subscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().ssubscribe(clientPresenceKey)));
if (existingIdAndListener != null) {
displacedListener.set(existingIdAndListener.listener());
}
return new ConnectionIdAndListener(connectionId, listener);
});
if (displacedListener.get() != null) {
listenerEventExecutor.execute(() -> displacedListener.get().handleConnectionDisplaced(true));
}
return subscribeFuture.get()
.thenCompose(ignored -> clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, buildClientConnectedMessage(connectionId))))
.handle((ignored, throwable) -> {
if (throwable != null) {
PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER.increment();
}
return connectionId;
});
}
/**
* Removes the "presence" for the given device. The presence is removed if and only if the given connection ID matches
* the connection ID for the currently-registered presence. Callers should call this method when they have closed or
* intend to close the client's underlying network connection.
*
* @param accountIdentifier the identifier of the account for the disconnected device
* @param deviceId the ID of the disconnected device within the given account
* @param connectionId the ID of the connection that has been closed (or will be closed)
*
* @return a future that completes when the presence has been removed
*/
public CompletionStage<Void> handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final UUID connectionId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(null);
}
final AtomicReference<CompletionStage<Void>> unsubscribeFuture = new AtomicReference<>();
// Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
// particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
//
// > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
// > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
// > computation is in progress, so the computation should be short and simple.
//
// This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(ignored, existingIdAndListener) -> {
final ConnectionIdAndListener remainingIdAndListener;
if (existingIdAndListener == null) {
remainingIdAndListener = null;
} else if (existingIdAndListener.connectionIdentifier().equals(connectionId)) {
remainingIdAndListener = null;
} else {
remainingIdAndListener = existingIdAndListener;
}
if (remainingIdAndListener == null) {
// Only unsubscribe if there's no listener remaining
unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId)))
.thenRun(Util.NOOP));
} else {
unsubscribeFuture.set(CompletableFuture.completedFuture(null));
}
return remainingIdAndListener;
});
return unsubscribeFuture.get()
.whenComplete((ignored, throwable) -> {
if (throwable != null) {
UNSUBSCRIBE_ERROR_COUNTER.increment();
}
});
}
/**
* Publishes an event notifying a specific device that a new message is available for retrieval. This method indicates
* whether the target device is "present" (i.e. has an active listener). Callers may choose to take follow-up action
* (like sending a push notification) if the target device is not present.
*
* @param accountIdentifier the account identifier of the receiving device
* @param deviceId the ID of the receiving device within the target account
*
* @return a future that yields {@code true} if the target device had an active listener or {@code false} otherwise
*/
public CompletionStage<Boolean> handleNewMessageAvailable(final UUID accountIdentifier, final byte deviceId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(false);
}
return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
.thenApply(listeners -> listeners > 0);
}
/**
* Tests whether a client with the given account/device is connected to this presence manager instance.
*
* @param accountUuid the account identifier for the client to check
* @param deviceId the ID of the device within the given account
*
* @return {@code true} if a client with the given account/device is connected to this presence manager instance or
* {@code false} if the client is not connected at all or is connected to a different presence manager instance
*/
public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) {
return listenersByAccountAndDeviceIdentifier.containsKey(new AccountAndDeviceIdentifier(accountUuid, deviceId));
}
/**
* Broadcasts a request that all devices associated with the identified account and connected to any client presence
* instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier) {
return requestDisconnection(accountIdentifier, Device.ALL_POSSIBLE_DEVICE_IDS);
}
/**
* Broadcasts a request that the specified devices associated with the identified account and connected to any client
* presence instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
* @param deviceIds the IDs of the devices for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
return CompletableFuture.allOf(deviceIds.stream()
.map(deviceId -> {
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
return clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES))
.toCompletableFuture();
})
.toArray(CompletableFuture[]::new));
}
@VisibleForTesting
void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent);
final Map<Integer, List<byte[]>> clientPresenceKeysBySlot = new HashMap<>();
// Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands
listenersByAccountAndDeviceIdentifier.keySet()
.stream()
.map(accountAndDeviceIdentifier -> getClientPresenceKey(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId()))
.forEach(clientPresenceKey -> {
final int slot = SlotHash.getSlot(clientPresenceKey);
if (changedSlots[slot]) {
clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientPresenceKey);
}
});
// Issue one resubscription command per affected slot
clientPresenceKeysBySlot.forEach((slot, clientPresenceKeys) -> {
if (pubSubConnection != null) {
final byte[][] clientPresenceKeyArray = clientPresenceKeys.toArray(byte[][]::new);
pubSubConnection.usePubSubConnection(connection -> connection.sync().ssubscribe(clientPresenceKeyArray));
}
});
}
@Override
public void smessage(final RedisClusterNode node, final byte[] shardChannel, final byte[] message) {
final ClientEvent clientEvent;
try {
clientEvent = ClientEvent.parseFrom(message);
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse pub/sub event protobuf", e);
return;
}
final AccountAndDeviceIdentifier accountAndDeviceIdentifier = parseClientPresenceKey(shardChannel);
@Nullable final ConnectionIdAndListener connectionIdAndListener =
listenersByAccountAndDeviceIdentifier.get(accountAndDeviceIdentifier);
if (connectionIdAndListener != null) {
switch (clientEvent.getEventCase()) {
case NEW_MESSAGE_AVAILABLE -> connectionIdAndListener.listener().handleNewMessageAvailable();
case CLIENT_CONNECTED -> {
final UUID connectionId = UUIDUtil.fromByteString(clientEvent.getClientConnected().getConnectionId());
if (!connectionIdAndListener.connectionIdentifier().equals(connectionId)) {
listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(true));
}
}
case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(false));
default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass());
}
} else {
MESSAGE_WITHOUT_LISTENER_COUNTER.increment();
}
}
private static byte[] buildClientConnectedMessage(final UUID connectionId) {
return ClientEvent.newBuilder()
.setClientConnected(ClientConnectedEvent.newBuilder()
.setConnectionId(UUIDUtil.toByteString(connectionId))
.build())
.build()
.toByteArray();
}
@VisibleForTesting
static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) {
return ("client_presence::{" + accountIdentifier + ":" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static AccountAndDeviceIdentifier parseClientPresenceKey(final byte[] clientPresenceKeyBytes) {
final String clientPresenceKey = new String(clientPresenceKeyBytes, StandardCharsets.UTF_8);
final int uuidStart = "client_presence::{".length();
final UUID accountIdentifier = UUID.fromString(clientPresenceKey.substring(uuidStart, uuidStart + 36));
final byte deviceId = Byte.parseByte(clientPresenceKey.substring(uuidStart + 37, clientPresenceKey.length() - 1));
return new AccountAndDeviceIdentifier(accountIdentifier, deviceId);
}
}

View File

@ -11,6 +11,7 @@ import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import java.util.function.Consumer;
public class FaultTolerantPubSubClusterConnection<K, V> extends AbstractFaultTolerantPubSubConnection<K, V, StatefulRedisClusterPubSubConnection<K, V>> { public class FaultTolerantPubSubClusterConnection<K, V> extends AbstractFaultTolerantPubSubConnection<K, V, StatefulRedisClusterPubSubConnection<K, V>> {
@ -32,7 +33,7 @@ public class FaultTolerantPubSubClusterConnection<K, V> extends AbstractFaultTol
this.topologyChangedEventScheduler = topologyChangedEventScheduler; this.topologyChangedEventScheduler = topologyChangedEventScheduler;
} }
public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) { public void subscribeToClusterTopologyChangedEvents(final Consumer<ClusterTopologyChangedEvent> eventHandler) {
usePubSubConnection(connection -> connection.getResources().eventBus().get() usePubSubConnection(connection -> connection.getResources().eventBus().get()
.filter(event -> { .filter(event -> {
@ -53,7 +54,7 @@ public class FaultTolerantPubSubClusterConnection<K, V> extends AbstractFaultTol
resubscribeRetry.executeRunnable(() -> { resubscribeRetry.executeRunnable(() -> {
try { try {
eventHandler.run(); eventHandler.accept((ClusterTopologyChangedEvent) event);
} catch (final RuntimeException e) { } catch (final RuntimeException e) {
logger.warn("Resubscribe for {} failed", getName(), e); logger.warn("Resubscribe for {} failed", getName(), e);
throw e; throw e;

View File

@ -202,4 +202,11 @@ public class FaultTolerantRedisClusterClient {
Schedulers.newSingle(name + "-redisPubSubEvents", true)); Schedulers.newSingle(name + "-redisPubSubEvents", true));
} }
public FaultTolerantPubSubClusterConnection<byte[], byte[]> createBinaryPubSubConnection() {
final StatefulRedisClusterPubSubConnection<byte[], byte[]> pubSubConnection = clusterClient.connectPubSub(ByteArrayCodec.INSTANCE);
pubSubConnections.add(pubSubConnection);
return new FaultTolerantPubSubClusterConnection<>(name, pubSubConnection, topologyChangedEventRetry,
Schedulers.newSingle(name + "-redisPubSubEvents", true));
}
} }

View File

@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
@ -126,6 +127,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final SecureStorageClient secureStorageClient; private final SecureStorageClient secureStorageClient;
private final SecureValueRecovery2Client secureValueRecovery2Client; private final SecureValueRecovery2Client secureValueRecovery2Client;
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final ClientPublicKeysManager clientPublicKeysManager; private final ClientPublicKeysManager clientPublicKeysManager;
private final Executor accountLockExecutor; private final Executor accountLockExecutor;
@ -205,6 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
final SecureStorageClient secureStorageClient, final SecureStorageClient secureStorageClient,
final SecureValueRecovery2Client secureValueRecovery2Client, final SecureValueRecovery2Client secureValueRecovery2Client,
final ClientPresenceManager clientPresenceManager, final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final ClientPublicKeysManager clientPublicKeysManager, final ClientPublicKeysManager clientPublicKeysManager,
final Executor accountLockExecutor, final Executor accountLockExecutor,
@ -223,6 +226,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
this.secureStorageClient = secureStorageClient; this.secureStorageClient = secureStorageClient;
this.secureValueRecovery2Client = secureValueRecovery2Client; this.secureValueRecovery2Client = secureValueRecovery2Client;
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager); this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.clientPublicKeysManager = clientPublicKeysManager; this.clientPublicKeysManager = clientPublicKeysManager;
this.accountLockExecutor = accountLockExecutor; this.accountLockExecutor = accountLockExecutor;
@ -329,7 +333,10 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
keysManager.deleteSingleUsePreKeys(pni), keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci), messagesManager.clear(aci),
profilesManager.deleteAll(aci)) profilesManager.deleteAll(aci))
.thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor) .thenRunAsync(() -> {
clientPresenceManager.disconnectAllPresencesForUuid(aci);
pubSubClientEventManager.requestDisconnection(aci);
}, clientPresenceExecutor)
.thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems)) .thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems))
.thenCompose(ignored -> { .thenCompose(ignored -> {
// We should have cleared all messages before overwriting the old account, but more may have arrived // We should have cleared all messages before overwriting the old account, but more may have arrived
@ -594,6 +601,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
.whenCompleteAsync((ignored, throwable) -> { .whenCompleteAsync((ignored, throwable) -> {
if (throwable == null) { if (throwable == null) {
RedisOperation.unchecked(() -> clientPresenceManager.disconnectPresence(accountIdentifier, deviceId)); RedisOperation.unchecked(() -> clientPresenceManager.disconnectPresence(accountIdentifier, deviceId));
pubSubClientEventManager.requestDisconnection(accountIdentifier, List.of(deviceId));
} }
}, clientPresenceExecutor); }, clientPresenceExecutor);
} }
@ -1240,9 +1248,11 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
registrationRecoveryPasswordsManager.removeForNumber(account.getNumber())) registrationRecoveryPasswordsManager.removeForNumber(account.getNumber()))
.thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems)) .thenCompose(ignored -> accounts.delete(account.getUuid(), additionalWriteItems))
.thenCompose(ignored -> redisDeleteAsync(account)) .thenCompose(ignored -> redisDeleteAsync(account))
.thenRunAsync(() -> RedisOperation.unchecked(() -> .thenRunAsync(() -> {
account.getDevices().forEach(device -> RedisOperation.unchecked(() -> clientPresenceManager.disconnectAllPresencesForUuid(account.getUuid()));
clientPresenceManager.disconnectPresence(account.getUuid(), device.getId()))), clientPresenceExecutor);
pubSubClientEventManager.requestDisconnection(account.getUuid());
}, clientPresenceExecutor);
} }
private String getAccountMapKey(String key) { private String getAccountMapKey(String key) {

View File

@ -13,6 +13,7 @@ import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ZAddArgs; import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
@ -247,7 +248,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
pubSubConnection.usePubSubConnection(connection -> connection.sync().upstream().commands().unsubscribe()); pubSubConnection.usePubSubConnection(connection -> connection.sync().upstream().commands().unsubscribe());
} }
private void resubscribeAll() { private void resubscribeAll(final ClusterTopologyChangedEvent event) {
final Set<String> queueNames; final Set<String> queueNames;

View File

@ -6,6 +6,12 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class RedisClusterUtil { public class RedisClusterUtil {
@ -38,4 +44,51 @@ public class RedisClusterUtil {
public static String getMinimalHashTag(final int slot) { public static String getMinimalHashTag(final int slot) {
return HASHES_BY_SLOT[slot]; return HASHES_BY_SLOT[slot];
} }
/**
* Returns an array indicating which slots have moved as part of a {@link ClusterTopologyChangedEvent}. The elements
* of the array map to slots in the cluster; for example, if slot 1234 has changed, then element 1234 of the returned
* array will be {@code true}.
*
* @param clusterTopologyChangedEvent the event from which to derive an array of changed slots
*
* @return an array indicating which slots of changed
*/
public static boolean[] getChangedSlots(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
final Map<String, RedisClusterNode> beforeNodesById = clusterTopologyChangedEvent.before().stream()
.collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node));
final Map<String, RedisClusterNode> afterNodesById = clusterTopologyChangedEvent.after().stream()
.collect(Collectors.toMap(RedisClusterNode::getNodeId, node -> node));
final Set<String> nodeIds = new HashSet<>(beforeNodesById.keySet());
nodeIds.addAll(afterNodesById.keySet());
final boolean[] changedSlots = new boolean[SlotHash.SLOT_COUNT];
for (final String nodeId : nodeIds) {
if (beforeNodesById.containsKey(nodeId) && afterNodesById.containsKey(nodeId)) {
// This node was present before and after the topology change, but its slots may have changed
final boolean[] beforeSlots = new boolean[SlotHash.SLOT_COUNT];
beforeNodesById.get(nodeId).getSlots().forEach(slot -> beforeSlots[slot] = true);
final boolean[] afterSlots = new boolean[SlotHash.SLOT_COUNT];
afterNodesById.get(nodeId).getSlots().forEach(slot -> afterSlots[slot] = true);
for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
changedSlots[slot] |= beforeSlots[slot] ^ afterSlots[slot];
}
} else if (beforeNodesById.containsKey(nodeId)) {
// The node was present before the topology change, but is gone now; all of its slots should be considered
// changed
beforeNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true);
} else {
// The node was present after the change, but wasn't there before; all of its slots should be considered
// changed
afterNodesById.get(nodeId).getSlots().forEach(slot -> changedSlots[slot] = true);
}
}
return changedSlots;
}
} }

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -19,6 +20,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -47,6 +49,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler; private final PushNotificationScheduler pushNotificationScheduler;
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final ScheduledExecutorService scheduledExecutorService; private final ScheduledExecutorService scheduledExecutorService;
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
@ -55,12 +58,15 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
private transient UUID connectionId;
public AuthenticatedConnectListener(ReceiptSender receiptSender, public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager, MessagesManager messagesManager,
MessageMetrics messageMetrics, MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler, PushNotificationScheduler pushNotificationScheduler,
ClientPresenceManager clientPresenceManager, ClientPresenceManager clientPresenceManager,
PubSubClientEventManager pubSubClientEventManager,
ScheduledExecutorService scheduledExecutorService, ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager, ClientReleaseManager clientReleaseManager,
@ -71,6 +77,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler; this.pushNotificationScheduler = pushNotificationScheduler;
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.scheduledExecutorService = scheduledExecutorService; this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
@ -121,6 +128,12 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// It's preferable to start sending push notifications as soon as possible. // It's preferable to start sending push notifications as soon as possible.
RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection)); RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection));
if (connectionId != null) {
pubSubClientEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
connectionId);
}
// Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection // Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection
// will not be notified and will not change its state, but that's okay because it has already closed and // will not be notified and will not change its state, but that's okay because it has already closed and
// attempts to deliver mesages via this connection will not succeed. // attempts to deliver mesages via this connection will not succeed.
@ -147,6 +160,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because // Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification. // receiving extra push notifications is generally preferable to missing out on a push notification.
clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection); clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
pubSubClientEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), null)
.thenAccept(connectionId -> this.connectionId = connectionId);
renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() -> renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() ->
clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())), clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())),

View File

@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientEventListener;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@ -63,15 +64,13 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener { public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener, ClientEventListener {
private static final DistributionSummary messageTime = Metrics.summary( private static final DistributionSummary messageTime = Metrics.summary(
name(MessageController.class, "messageDeliveryDuration")); name(MessageController.class, "messageDeliveryDuration"));
private static final DistributionSummary primaryDeviceMessageTime = Metrics.summary( private static final DistributionSummary primaryDeviceMessageTime = Metrics.summary(
name(MessageController.class, "primaryDeviceMessageDeliveryDuration")); name(MessageController.class, "primaryDeviceMessageDeliveryDuration"));
private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage")); private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage"));
private static final Counter messageAvailableCounter = Metrics.counter(
name(WebSocketConnection.class, "messagesAvailable"));
private static final Counter messagesPersistedCounter = Metrics.counter( private static final Counter messagesPersistedCounter = Metrics.counter(
name(WebSocketConnection.class, "messagesPersisted")); name(WebSocketConnection.class, "messagesPersisted"));
private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent")); private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent"));
@ -91,6 +90,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
"sendMessages"); "sendMessages");
private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class,
"sendMessageError"); "sendMessageError");
private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable");
private static final String PRESENCE_MANAGER_TAG = "presenceManager";
private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message"; private static final String STATUS_MESSAGE_TAG = "message";
private static final String ERROR_TYPE_TAG = "errorType"; private static final String ERROR_TYPE_TAG = "errorType";
@ -468,7 +470,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
return false; return false;
} }
messageAvailableCounter.increment(); Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME,
PRESENCE_MANAGER_TAG, "legacy")
.increment();
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
@ -477,6 +481,13 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
return true; return true;
} }
@Override
public void handleNewMessageAvailable() {
Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME,
PRESENCE_MANAGER_TAG, "pubsub")
.increment();
}
@Override @Override
public boolean handleMessagesPersisted() { public boolean handleMessagesPersisted() {
if (!client.isOpen()) { if (!client.isOpen()) {
@ -497,7 +508,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public void handleDisplacement(final boolean connectedElsewhere) { public void handleDisplacement(final boolean connectedElsewhere) {
final Tags tags = Tags.of( final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()), UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)) Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
Tag.of(PRESENCE_MANAGER_TAG, "legacy")
); );
Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
@ -522,6 +534,17 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
} }
} }
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
Tag.of(PRESENCE_MANAGER_TAG, "pubsub")
);
Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
}
private record StoredMessageInfo(UUID guid, long serverTimestamp) { private record StoredMessageInfo(UUID guid, long serverTimestamp) {
} }

View File

@ -31,10 +31,12 @@ import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.FcmSender;
@ -141,6 +143,8 @@ record CommandDependencies(
.maxThreads(1).minThreads(1).build(); .maxThreads(1).minThreads(1).build();
ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d")) ExecutorService fcmSenderExecutor = environment.lifecycle().executorService(name(name, "fcmSender-%d"))
.maxThreads(16).minThreads(16).build(); .maxThreads(16).minThreads(16).build();
ExecutorService clientEventExecutor = environment.lifecycle()
.virtualExecutorService(name(name, "clientEvent-%d"));
ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle() ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build(); .scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build();
@ -214,6 +218,9 @@ record CommandDependencies(
storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration()); storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor); recurringJobExecutor, keyspaceNotificationDispatchExecutor);
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager); messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
@ -230,7 +237,7 @@ record CommandDependencies(
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager, pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager, secureStorageClient, secureValueRecovery2Client, clientPresenceManager, pubSubClientEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
@ -269,6 +276,7 @@ record CommandDependencies(
environment.lifecycle().manage(apnSender); environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(clientPresenceManager); environment.lifecycle().manage(clientPresenceManager);
environment.lifecycle().manage(pubSubClientEventManager);
environment.lifecycle().manage(new ManagedAwsCrt()); environment.lifecycle().manage(new ManagedAwsCrt());
return new CommandDependencies( return new CommandDependencies(

View File

@ -0,0 +1,38 @@
/**
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
syntax = "proto3";
package org.signal.chat.presence;
option java_package = "org.whispersystems.textsecuregcm.push";
option java_multiple_files = true;
message ClientEvent {
oneof event {
NewMessageAvailableEvent new_message_available = 1;
ClientConnectedEvent client_connected = 2;
DisconnectRequested disconnect_requested = 3;
}
}
/**
* Indicates that a new message is available for the client to retrieve.
*/
message NewMessageAvailableEvent {
}
/**
* Indicates that a client has connected to the presence system.
*/
message ClientConnectedEvent {
bytes connection_id = 1;
}
/**
* Indicates that the server has requested that the client disconnect due to
* (for example) account lifecycle events.
*/
message DisconnectRequested {
}

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -61,6 +60,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -97,6 +97,7 @@ class LinkedDeviceRefreshRequirementProviderTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
private ClientPresenceManager clientPresenceManager; private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private LinkedDeviceRefreshRequirementProvider provider; private LinkedDeviceRefreshRequirementProvider provider;
@ -104,11 +105,12 @@ class LinkedDeviceRefreshRequirementProviderTest {
void setup() { void setup() {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
clientPresenceManager = mock(ClientPresenceManager.class); clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager); provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener = final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(clientPresenceManager, provider); new WebsocketRefreshRequestEventListener(clientPresenceManager, pubSubClientEventManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener); when(applicationEventListener.onRequest(any())).thenReturn(listener);
@ -146,6 +148,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1); verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2); verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3); verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3);
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
} }
@ParameterizedTest @ParameterizedTest
@ -173,8 +179,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId -> initialDeviceIds.forEach(deviceId -> {
verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId)); verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId);
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(deviceId));
});
verifyNoMoreInteractions(clientPresenceManager); verifyNoMoreInteractions(clientPresenceManager);
} }

View File

@ -28,6 +28,7 @@ import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
@ -47,6 +48,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -74,6 +76,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class); private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class); private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
private static final PubSubClientEventManager PUBSUB_CLIENT_PRESENCE = mock(PubSubClientEventManager.class);
private WebSocketClient client; private WebSocketClient client;
private final Account account1 = new Account(); private final Account account1 = new Account();
@ -122,9 +125,9 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter()); webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey() webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
environment.jersey() environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> { webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
}); });
@ -215,6 +218,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000)) verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE); verifyNoMoreInteractions(CLIENT_PRESENCE);
verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
} }
@Test @Test
@ -231,6 +238,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000)) verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE); verifyNoMoreInteractions(CLIENT_PRESENCE);
verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
} }
@ParameterizedTest @ParameterizedTest

View File

@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -47,6 +48,7 @@ class RegistrationLockVerificationManagerTest {
private final AccountsManager accountsManager = mock(AccountsManager.class); private final AccountsManager accountsManager = mock(AccountsManager.class);
private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock( private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock(
ExternalServiceCredentialsGenerator.class); ExternalServiceCredentialsGenerator.class);
private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock( private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock(
@ -56,7 +58,7 @@ class RegistrationLockVerificationManagerTest {
private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator, accountsManager, clientPresenceManager, pubSubClientEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters); registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
private final RateLimiter pinLimiter = mock(RateLimiter.class); private final RateLimiter pinLimiter = mock(RateLimiter.class);
@ -108,6 +110,7 @@ class RegistrationLockVerificationManagerTest {
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
} }
verify(clientPresenceManager).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); verify(clientPresenceManager).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(Device.PRIMARY_ID));
try { try {
verify(pushNotificationManager).sendAttemptLoginNotification(any(), eq("failedRegistrationLock")); verify(pushNotificationManager).sendAttemptLoginNotification(any(), eq("failedRegistrationLock"));
} catch (NotPushRegisteredException npre) {} } catch (NotPushRegisteredException npre) {}
@ -131,6 +134,7 @@ class RegistrationLockVerificationManagerTest {
} catch (NotPushRegisteredException npre) {} } catch (NotPushRegisteredException npre) {}
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
}); });
} }
}; };
@ -169,6 +173,7 @@ class RegistrationLockVerificationManagerTest {
verify(account, never()).lockAuthTokenHash(); verify(account, never()).lockAuthTokenHash();
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber()); verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID)); verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
} }
static Stream<Arguments> testSuccess() { static Stream<Arguments> testSuccess() {

View File

@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@ -111,6 +112,7 @@ class DeviceControllerTest {
private static final Account maxedAccount = mock(Account.class); private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class); private static final Device primaryDevice = mock(Device.class);
private static final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); private static final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private static final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private static final Map<String, Integer> deviceConfiguration = new HashMap<>(); private static final Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now(); private static final TestClock testClock = TestClock.now();
@ -131,7 +133,8 @@ class DeviceControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper()) .addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)) .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
pubSubClientEventManager))
.addProvider(new DeviceLimitExceededExceptionMapper()) .addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController) .addResource(deviceController)
.build(); .build();

View File

@ -21,6 +21,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -39,6 +40,7 @@ class MessageSenderTest {
private MessageProtos.Envelope message; private MessageProtos.Envelope message;
private ClientPresenceManager clientPresenceManager; private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private MessagesManager messagesManager; private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager; private PushNotificationManager pushNotificationManager;
private MessageSender messageSender; private MessageSender messageSender;
@ -54,9 +56,14 @@ class MessageSenderTest {
message = generateRandomMessage(); message = generateRandomMessage();
clientPresenceManager = mock(ClientPresenceManager.class); clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
messagesManager = mock(MessagesManager.class); messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class); pushNotificationManager = mock(PushNotificationManager.class);
messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager);
when(pubSubClientEventManager.handleNewMessageAvailable(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(true));
messageSender = new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager);
when(account.getUuid()).thenReturn(ACCOUNT_UUID); when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(device.getId()).thenReturn(DEVICE_ID); when(device.getId()).thenReturn(DEVICE_ID);

View File

@ -0,0 +1,337 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class PubSubClientEventManagerTest {
private PubSubClientEventManager localPresenceManager;
private PubSubClientEventManager remotePresenceManager;
private static ExecutorService clientEventExecutor;
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static class ClientEventAdapter implements ClientEventListener {
@Override
public void handleNewMessageAvailable() {
}
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
}
}
@BeforeAll
static void setUpBeforeAll() {
clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
}
@BeforeEach
void setUp() {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
localPresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
remotePresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
localPresenceManager.start();
remotePresenceManager.start();
}
@AfterEach
void tearDown() {
localPresenceManager.stop();
remotePresenceManager.stop();
}
@AfterAll
static void tearDownAfterAll() {
clientEventExecutor.shutdown();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void handleClientConnected(final boolean displaceRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (firstListenerDisplaced) {
firstListenerDisplaced.set(true);
firstListenerConnectedElsewhere.set(connectedElsewhere);
firstListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
assertFalse(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
final PubSubClientEventManager displacingManager =
displaceRemotely ? remotePresenceManager : localPresenceManager;
displacingManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
secondListenerDisplaced.set(true);
}
}).toCompletableFuture().join();
synchronized (firstListenerDisplaced) {
while (!firstListenerDisplaced.get()) {
firstListenerDisplaced.wait();
}
}
assertTrue(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
assertTrue(firstListenerConnectedElsewhere.get());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void handleNewMessageAvailable(final boolean messageAvailableRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final AtomicBoolean messageReceived = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleNewMessageAvailable() {
synchronized (messageReceived) {
messageReceived.set(true);
messageReceived.notifyAll();
}
}
}).toCompletableFuture().join();
final PubSubClientEventManager messagePresenceManager =
messageAvailableRemotely ? remotePresenceManager : localPresenceManager;
assertTrue(messagePresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
synchronized (messageReceived) {
while (!messageReceived.get()) {
messageReceived.wait();
}
}
assertTrue(messageReceived.get());
}
@Test
void handleClientDisconnected() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture().join();
assertTrue(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId).toCompletableFuture().join();
assertFalse(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
}
@Test
void isLocallyPresent() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture()
.join();
assertTrue(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId)
.toCompletableFuture()
.join();
assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID;
final byte secondDeviceId = firstDeviceId + 1;
final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, firstDeviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (firstListenerDisplaced) {
firstListenerDisplaced.set(true);
firstListenerConnectedElsewhere.set(connectedElsewhere);
firstListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
localPresenceManager.handleClientConnected(accountIdentifier, secondDeviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (secondListenerDisplaced) {
secondListenerDisplaced.set(true);
secondListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
assertFalse(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
final PubSubClientEventManager displacingManager =
requestDisconnectionRemotely ? remotePresenceManager : localPresenceManager;
displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join();
synchronized (firstListenerDisplaced) {
while (!firstListenerDisplaced.get()) {
firstListenerDisplaced.wait();
}
}
assertTrue(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
assertFalse(firstListenerConnectedElsewhere.get());
}
@Test
void resubscribe() {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
@SuppressWarnings("unchecked") final RedisClusterPubSubCommands<byte[], byte[]> pubSubCommands =
mock(RedisClusterPubSubCommands.class);
@SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands<byte[], byte[]> pubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null));
final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder()
.binaryPubSubCommands(pubSubCommands)
.binaryPubSubAsyncCommands(pubSubAsyncCommands)
.build();
final PubSubClientEventManager presenceManager =
new PubSubClientEventManager(clusterClient, Runnable::run, experimentEnrollmentManager);
presenceManager.start();
final UUID firstAccountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID;
final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
final UUID secondAccountIdentifier;
final byte secondDeviceId = firstDeviceId + 1;
// Make sure that the two subscriptions wind up in different slots
{
UUID candidateIdentifier;
do {
candidateIdentifier = UUID.randomUUID();
} while (SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(candidateIdentifier, secondDeviceId)) == firstSlot);
secondAccountIdentifier = candidateIdentifier;
}
presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
final String firstNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstBeforeNode = mock(RedisClusterNode.class);
when(firstBeforeNode.getNodeId()).thenReturn(firstNodeId);
when(firstBeforeNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT).boxed().toList());
final RedisClusterNode firstAfterNode = mock(RedisClusterNode.class);
when(firstAfterNode.getNodeId()).thenReturn(firstNodeId);
when(firstAfterNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT)
.filter(slot -> slot != secondSlot)
.boxed()
.toList());
final RedisClusterNode secondAfterNode = mock(RedisClusterNode.class);
when(secondAfterNode.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(secondAfterNode.getSlots()).thenReturn(List.of(secondSlot));
presenceManager.resubscribe(new ClusterTopologyChangedEvent(
List.of(firstBeforeNode),
List.of(firstAfterNode, secondAfterNode)));
verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
}
}

View File

@ -31,6 +31,7 @@ import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
@ -46,7 +47,7 @@ class FaultTolerantPubSubClusterConnectionTest {
private TestPublisher<Event> eventPublisher; private TestPublisher<Event> eventPublisher;
private Runnable resubscribe; private Consumer<ClusterTopologyChangedEvent> resubscribe;
private AtomicInteger resubscribeCounter; private AtomicInteger resubscribeCounter;
private CountDownLatch resubscribeFailure; private CountDownLatch resubscribeFailure;
@ -93,7 +94,7 @@ class FaultTolerantPubSubClusterConnectionTest {
resubscribeCounter = new AtomicInteger(); resubscribeCounter = new AtomicInteger();
resubscribe = () -> { resubscribe = event -> {
try { try {
resubscribeCounter.incrementAndGet(); resubscribeCounter.incrementAndGet();
pubSubConnection.sync().nodes((ignored) -> true); pubSubConnection.sync().nodes((ignored) -> true);
@ -137,7 +138,7 @@ class FaultTolerantPubSubClusterConnectionTest {
void testFilterClusterTopologyChangeEvents() throws InterruptedException { void testFilterClusterTopologyChangeEvents() throws InterruptedException {
final CountDownLatch topologyEventLatch = new CountDownLatch(1); final CountDownLatch topologyEventLatch = new CountDownLatch(1);
faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(topologyEventLatch::countDown); faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(event -> topologyEventLatch.countDown());
final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class); final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class);

View File

@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -152,6 +153,7 @@ public class AccountCreationDeletionIntegrationTest {
secureStorageClient, secureStorageClient,
svr2Client, svr2Client,
mock(ClientPresenceManager.class), mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, accountLockExecutor,

View File

@ -37,6 +37,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -67,6 +68,7 @@ class AccountsManagerChangeNumberIntegrationTest {
private KeysManager keysManager; private KeysManager keysManager;
private ClientPresenceManager clientPresenceManager; private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private ExecutorService accountLockExecutor; private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor; private ExecutorService clientPresenceExecutor;
@ -119,6 +121,7 @@ class AccountsManagerChangeNumberIntegrationTest {
when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null)); when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
clientPresenceManager = mock(ClientPresenceManager.class); clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
final PhoneNumberIdentifiers phoneNumberIdentifiers = final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName()); new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName());
@ -147,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest {
secureStorageClient, secureStorageClient,
svr2Client, svr2Client,
clientPresenceManager, clientPresenceManager,
pubSubClientEventManager,
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, accountLockExecutor,
@ -281,7 +285,8 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow()); assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow());
verify(clientPresenceManager).disconnectPresence(existingAccountUuid, Device.PRIMARY_ID); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingAccountUuid);
verify(pubSubClientEventManager).requestDisconnection(existingAccountUuid);
assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalNumber)); assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalNumber));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondNumber)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondNumber));

View File

@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -134,6 +135,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(SecureStorageClient.class), mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class), mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class), mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
mock(Executor.class), mock(Executor.class),

View File

@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -63,6 +64,7 @@ public class AccountsManagerDeviceTransferIntegrationTest {
mock(SecureStorageClient.class), mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class), mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class), mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
mock(ExecutorService.class), mock(ExecutorService.class),

View File

@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -118,6 +119,7 @@ class AccountsManagerTest {
private MessagesManager messagesManager; private MessagesManager messagesManager;
private ProfilesManager profilesManager; private ProfilesManager profilesManager;
private ClientPresenceManager clientPresenceManager; private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private ClientPublicKeysManager clientPublicKeysManager; private ClientPublicKeysManager clientPublicKeysManager;
private Map<String, UUID> phoneNumberIdentifiersByE164; private Map<String, UUID> phoneNumberIdentifiersByE164;
@ -153,6 +155,7 @@ class AccountsManagerTest {
messagesManager = mock(MessagesManager.class); messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class); profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class); clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
clientPublicKeysManager = mock(ClientPublicKeysManager.class); clientPublicKeysManager = mock(ClientPublicKeysManager.class);
dynamicConfiguration = mock(DynamicConfiguration.class); dynamicConfiguration = mock(DynamicConfiguration.class);
@ -259,6 +262,7 @@ class AccountsManagerTest {
storageClient, storageClient,
svr2Client, svr2Client,
clientPresenceManager, clientPresenceManager,
pubSubClientEventManager,
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
mock(Executor.class), mock(Executor.class),
@ -799,6 +803,7 @@ class AccountsManagerTest {
verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId()); verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId());
verify(clientPublicKeysManager).buildTransactWriteItemForDeletion(account.getUuid(), linkedDevice.getId()); verify(clientPublicKeysManager).buildTransactWriteItemForDeletion(account.getUuid(), linkedDevice.getId());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId()); verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId());
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(linkedDevice.getId()));
} }
@Test @Test
@ -817,6 +822,7 @@ class AccountsManagerTest {
verify(messagesManager, never()).clear(any(), anyByte()); verify(messagesManager, never()).clear(any(), anyByte());
verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte()); verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte());
verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte()); verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte());
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
} }
@Test @Test
@ -886,6 +892,7 @@ class AccountsManagerTest {
verify(messagesManager, times(2)).clear(existingUuid); verify(messagesManager, times(2)).clear(existingUuid);
verify(profilesManager, times(2)).deleteAll(existingUuid); verify(profilesManager, times(2)).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
verify(pubSubClientEventManager).requestDisconnection(existingUuid);
} }
@Test @Test

View File

@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -146,6 +147,7 @@ class AccountsManagerUsernameIntegrationTest {
mock(SecureStorageClient.class), mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class), mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class), mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor(),

View File

@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -152,6 +153,7 @@ public class AddRemoveDeviceIntegrationTest {
secureStorageClient, secureStorageClient,
svr2Client, svr2Client,
mock(ClientPresenceManager.class), mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, accountLockExecutor,

View File

@ -14,8 +14,12 @@ import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
public class RedisClusterHelper { public class RedisClusterHelper {
@ -30,7 +34,12 @@ public class RedisClusterHelper {
final RedisAdvancedClusterAsyncCommands<String, String> stringAsyncCommands, final RedisAdvancedClusterAsyncCommands<String, String> stringAsyncCommands,
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands, final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands,
final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands, final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands,
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands) { final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands,
final RedisClusterPubSubCommands<String, String> stringPubSubCommands,
final RedisClusterPubSubAsyncCommands<String, String> stringAsyncPubSubCommands,
final RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands,
final RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryAsyncPubSubCommands) {
final FaultTolerantRedisClusterClient cluster = mock(FaultTolerantRedisClusterClient.class); final FaultTolerantRedisClusterClient cluster = mock(FaultTolerantRedisClusterClient.class);
final StatefulRedisClusterConnection<String, String> stringConnection = mock(StatefulRedisClusterConnection.class); final StatefulRedisClusterConnection<String, String> stringConnection = mock(StatefulRedisClusterConnection.class);
final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisClusterConnection.class); final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisClusterConnection.class);
@ -59,6 +68,45 @@ public class RedisClusterHelper {
return null; return null;
}).when(cluster).useBinaryCluster(any(Consumer.class)); }).when(cluster).useBinaryCluster(any(Consumer.class));
final StatefulRedisClusterPubSubConnection<String, String> stringPubSubConnection =
mock(StatefulRedisClusterPubSubConnection.class);
final StatefulRedisClusterPubSubConnection<byte[], byte[]> binaryPubSubConnection =
mock(StatefulRedisClusterPubSubConnection.class);
final FaultTolerantPubSubClusterConnection<String, String> faultTolerantPubSubClusterConnection =
mock(FaultTolerantPubSubClusterConnection.class);
final FaultTolerantPubSubClusterConnection<byte[], byte[]> faultTolerantBinaryPubSubClusterConnection =
mock(FaultTolerantPubSubClusterConnection.class);
when(stringPubSubConnection.sync()).thenReturn(stringPubSubCommands);
when(stringPubSubConnection.async()).thenReturn(stringAsyncPubSubCommands);
when(binaryPubSubConnection.sync()).thenReturn(binaryPubSubCommands);
when(binaryPubSubConnection.async()).thenReturn(binaryAsyncPubSubCommands);
when(cluster.createPubSubConnection()).thenReturn(faultTolerantPubSubClusterConnection);
when(cluster.createBinaryPubSubConnection()).thenReturn(faultTolerantBinaryPubSubClusterConnection);
when(faultTolerantPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringPubSubConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringPubSubConnection);
return null;
}).when(faultTolerantPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
when(faultTolerantBinaryPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(
invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryPubSubConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryPubSubConnection);
return null;
}).when(faultTolerantBinaryPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
return cluster; return cluster;
} }
@ -77,6 +125,18 @@ public class RedisClusterHelper {
private RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands = private RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands =
mock(RedisAdvancedClusterReactiveCommands.class); mock(RedisAdvancedClusterReactiveCommands.class);
private RedisClusterPubSubCommands<String, String> stringPubSubCommands =
mock(RedisClusterPubSubCommands.class);
private RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands =
mock(RedisClusterPubSubCommands.class);
private RedisClusterPubSubAsyncCommands<String, String> stringPubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
private RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryPubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
private Builder() { private Builder() {
} }
@ -107,9 +167,33 @@ public class RedisClusterHelper {
return this; return this;
} }
public Builder stringPubSubCommands(final RedisClusterPubSubCommands<String, String> stringPubSubCommands) {
this.stringPubSubCommands = stringPubSubCommands;
return this;
}
public Builder binaryPubSubCommands(final RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands) {
this.binaryPubSubCommands = binaryPubSubCommands;
return this;
}
public Builder stringPubSubAsyncCommands(
final RedisClusterPubSubAsyncCommands<String, String> stringPubSubAsyncCommands) {
this.stringPubSubAsyncCommands = stringPubSubAsyncCommands;
return this;
}
public Builder binaryPubSubAsyncCommands(
final RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryPubSubAsyncCommands) {
this.binaryPubSubAsyncCommands = binaryPubSubAsyncCommands;
return this;
}
public FaultTolerantRedisClusterClient build() { public FaultTolerantRedisClusterClient build() {
return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands, binaryAsyncCommands, return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands,
binaryReactiveCommands); binaryAsyncCommands,
binaryReactiveCommands, stringPubSubCommands, stringPubSubAsyncCommands, binaryPubSubCommands,
binaryPubSubAsyncCommands);
} }
} }

View File

@ -5,17 +5,199 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
class RedisClusterUtilTest { class RedisClusterUtilTest {
@Test @Test
void testGetMinimalHashTag() { void testGetMinimalHashTag() {
for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) { for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot))); assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot)));
}
} }
}
@ParameterizedTest
@MethodSource
void getChangedSlots(final ClusterTopologyChangedEvent event, final boolean[] expectedSlotsChanged) {
assertArrayEquals(expectedSlotsChanged, RedisClusterUtil.getChangedSlots(event));
}
private static List<Arguments> getChangedSlots() {
final List<Arguments> arguments = new ArrayList<>();
// Slot moved from one node to another
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore),
List.of(firstNodeAfter, secondNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
slotsChanged[8191] = true;
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// New node added to cluster
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 12288));
final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
when(thirdNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(12288, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore),
List.of(firstNodeAfter, secondNodeAfter, thirdNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
for (int slot = 12288; slot < 16384; slot++) {
slotsChanged[slot] = true;
}
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// Node removed from cluster
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 12288));
final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
when(thirdNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(12288, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
List.of(firstNodeAfter, secondNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
for (int slot = 12288; slot < 16384; slot++) {
slotsChanged[slot] = true;
}
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// Node added, node removed, and slot moved
// Node removed from cluster
{
final String secondNodeId = UUID.randomUUID().toString();
final String thirdNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 1));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(1, 8192));
final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
when(thirdNodeBefore.getNodeId()).thenReturn(thirdNodeId);
when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
when(thirdNodeAfter.getNodeId()).thenReturn(thirdNodeId);
when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16383));
final RedisClusterNode fourthNodeAfter = mock(RedisClusterNode.class);
when(fourthNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(fourthNodeAfter.getSlots()).thenReturn(getSlotRange(16383, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
List.of(secondNodeAfter, thirdNodeAfter, fourthNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
slotsChanged[0] = true;
slotsChanged[8191] = true;
slotsChanged[16383] = true;
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
return arguments;
}
private static List<Integer> getSlotRange(final int startInclusive, final int endExclusive) {
final List<Integer> slots = new ArrayList<>(endExclusive - startInclusive);
for (int i = startInclusive; i < endExclusive; i++) {
slots.add(i);
}
return slots;
}
} }

View File

@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -124,8 +125,8 @@ class WebSocketConnectionTest {
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager, mock(ClientPresenceManager.class), mock(PubSubClientEventManager.class), retrySchedulingExecutor,
mock(MessageDeliveryLoopMonitor.class)); messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))