Detect message delivery loops

This commit is contained in:
Jon Chambers 2024-08-30 16:27:21 -04:00 committed by GitHub
parent 4c628b1cd9
commit f09cc03164
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 182 additions and 13 deletions

View File

@ -165,6 +165,7 @@ import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -675,6 +676,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getRedeemedReceipts().getExpiration());
Subscriptions subscriptions = new Subscriptions(
config.getDynamoDbTables().getSubscriptions().getTableName(), dynamoDbAsyncClient);
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor =
new MessageDeliveryLoopMonitor(rateLimitersCluster);
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
@ -1015,7 +1018,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler,
clientReleaseManager));
clientReleaseManager, messageDeliveryLoopMonitor));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
@ -1118,7 +1121,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager,
multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager,
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, Clock.systemUTC()),
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor,
Clock.systemUTC()),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,

View File

@ -106,6 +106,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@ -168,6 +169,7 @@ public class MessageController {
private final ServerSecretParams serverSecretParams;
private final SpamChecker spamChecker;
private final MessageMetrics messageMetrics;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final Clock clock;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
@ -230,6 +232,7 @@ public class MessageController {
final ServerSecretParams serverSecretParams,
final SpamChecker spamChecker,
final MessageMetrics messageMetrics,
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final Clock clock) {
this.rateLimiters = rateLimiters;
this.messageByteLimitEstimator = messageByteLimitEstimator;
@ -248,6 +251,7 @@ public class MessageController {
this.serverSecretParams = serverSecretParams;
this.spamChecker = spamChecker;
this.messageMetrics = messageMetrics;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.clock = clock;
}
@ -785,6 +789,14 @@ public class MessageController {
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages));
if (!messages.messages().isEmpty()) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messages.messages().getFirst().guid(),
userAgent,
"rest");
}
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), auth.getAuthenticatedDevice(), NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}

View File

@ -0,0 +1,72 @@
package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class MessageDeliveryLoopMonitor {
private final ClusterLuaScript getDeliveryAttemptsScript;
private static final Duration DELIVERY_ATTEMPTS_COUNTER_TTL = Duration.ofHours(1);
private static final int DELIVERY_LOOP_THRESHOLD = 5;
private static final Logger logger = LoggerFactory.getLogger(MessageDeliveryLoopMonitor.class);
public MessageDeliveryLoopMonitor(final FaultTolerantRedisCluster rateLimitCluster) {
try {
getDeliveryAttemptsScript =
ClusterLuaScript.fromResource(rateLimitCluster, "lua/get_delivery_attempt_count.lua", ScriptOutputType.INTEGER);
} catch (final IOException e) {
throw new UncheckedIOException("Failed to load 'get delivery attempt count' script", e);
}
}
/**
* Records an attempt to deliver a message with the given GUID to the given account/device pair and returns the number
* of consecutive attempts to deliver the same message and logs a warning if the message appears to be in a delivery
* loop. This method is intended to detect cases where a message remains at the head of a device's queue after
* repeated attempts to deliver the message, and so the given message GUID should be the first message of a "page"
* sent to clients.
*
* @param accountIdentifier the identifier of the destination account
* @param deviceId the destination device's ID within the given account
* @param messageGuid the GUID of the message
* @param userAgent the User-Agent header supplied by the caller
* @param context a human-readable string identifying the mechanism of message delivery (e.g. "rest" or "websocket")
*/
public void recordDeliveryAttempt(final UUID accountIdentifier,
final byte deviceId,
final UUID messageGuid,
final String userAgent,
final String context) {
incrementDeliveryAttemptCount(accountIdentifier, deviceId, messageGuid)
.thenAccept(deliveryAttemptCount -> {
if (deliveryAttemptCount == DELIVERY_LOOP_THRESHOLD) {
logger.warn("Detected loop delivering message {} via {} to {}:{} ({})",
messageGuid, accountIdentifier, deviceId, context, userAgent);
}
});
}
@VisibleForTesting
CompletableFuture<Long> incrementDeliveryAttemptCount(final UUID accountIdentifier, final byte deviceId, final UUID messageGuid) {
final String firstMessageGuidKey = "firstMessageGuid::{" + accountIdentifier + ":" + deviceId + "}";
final String deliveryAttemptsKey = "firstMessageDeliveryAttempts::{" + accountIdentifier + ":" + deviceId + "}";
return getDeliveryAttemptsScript.executeAsync(
List.of(firstMessageGuidKey, deliveryAttemptsKey),
List.of(messageGuid.toString(), String.valueOf(DELIVERY_ATTEMPTS_COUNTER_TTL.toSeconds())))
.thenApply(result -> (long) result);
}
}

View File

@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@ -58,6 +59,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final ScheduledExecutorService scheduledExecutorService;
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final Map<ClientPlatform, AtomicInteger> openAuthenticatedWebsocketsByClientPlatform;
private final Map<ClientPlatform, AtomicInteger> openUnauthenticatedWebsocketsByClientPlatform;
@ -77,7 +79,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ClientPresenceManager clientPresenceManager,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager) {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
@ -87,6 +90,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
openAuthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class);
openUnauthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class);
@ -151,7 +155,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
context.getClient(),
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager);
clientReleaseManager,
messageDeliveryLoopMonitor);
openWebsocketAtomicInteger.incrementAndGet();

View File

@ -39,7 +39,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@ -117,6 +119,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private final MessageMetrics messageMetrics;
private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final AuthenticatedDevice auth;
private final WebSocketClient client;
@ -155,7 +158,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager) {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
this(receiptSender,
messagesManager,
@ -167,7 +171,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager);
clientReleaseManager,
messageDeliveryLoopMonitor);
}
@VisibleForTesting
@ -181,7 +186,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager) {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
@ -194,6 +200,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
}
public void start() {
@ -378,12 +385,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final AtomicBoolean hasErrored = new AtomicBoolean();
final Disposable subscription = Flux.from(messages)
.name(SEND_MESSAGES_FLUX_NAME)
.tap(Micrometer.metrics(Metrics.globalRegistry))
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.doOnNext(envelope -> {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
UUID.fromString(envelope.getServerGuid()),
client.getUserAgent(),
"websocket");
}
})
.flatMapSequential(envelope ->
Mono.fromFuture(() -> sendMessage(envelope)
.orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS))

View File

@ -0,0 +1,13 @@
local firstMessageGuidKey = KEYS[1]
local firstMessageAttemptsKey = KEYS[2]
local firstMessageGuid = ARGV[1]
local ttlSeconds = ARGV[2]
if firstMessageGuid ~= redis.call("GET", firstMessageGuidKey) then
-- This is the first time we've attempted to deliver this message as the first message in a "page"
redis.call("SET", firstMessageGuidKey, firstMessageGuid, "EX", ttlSeconds)
redis.call("SET", firstMessageAttemptsKey, 0, "EX", ttlSeconds)
end
return redis.call("INCR", firstMessageAttemptsKey)

View File

@ -103,6 +103,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
@ -205,7 +206,8 @@ class MessageControllerTest {
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams, SpamChecker.noop(), new MessageMetrics(), clock))
serverSecretParams, SpamChecker.noop(), new MessageMetrics(), mock(MessageDeliveryLoopMonitor.class),
clock))
.build();
@BeforeEach

View File

@ -0,0 +1,38 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Device;
class MessageDeliveryLoopMonitorTest {
private MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@BeforeEach
void setUp() {
messageDeliveryLoopMonitor = new MessageDeliveryLoopMonitor(REDIS_CLUSTER_EXTENSION.getRedisCluster());
}
@Test
void incrementDeliveryAttemptCount() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
assertEquals(1, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, UUID.randomUUID()).join());
assertEquals(1, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, UUID.randomUUID()).join());
final UUID repeatedDeliveryGuid = UUID.randomUUID();
for (int i = 1; i < 10; i++) {
assertEquals(i, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, repeatedDeliveryGuid).join());
}
}
}

View File

@ -46,6 +46,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@ -133,7 +134,8 @@ class WebSocketConnectionIntegrationTest {
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager);
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class));
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
@ -220,7 +222,8 @@ class WebSocketConnectionIntegrationTest {
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager);
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class));
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@ -289,7 +292,8 @@ class WebSocketConnectionIntegrationTest {
100, // use a very short timeout, so that this test completes quickly
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager);
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class));
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;

View File

@ -56,6 +56,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -124,7 +125,8 @@ class WebSocketConnectionTest {
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@ -628,7 +630,7 @@ class WebSocketConnectionTest {
private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager);
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
}
@Test