diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 333fbe971..7a6553fcc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -165,6 +165,7 @@ import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -675,6 +676,8 @@ public class WhisperServerService extends Application { + if (deliveryAttemptCount == DELIVERY_LOOP_THRESHOLD) { + logger.warn("Detected loop delivering message {} via {} to {}:{} ({})", + messageGuid, accountIdentifier, deviceId, context, userAgent); + } + }); + } + + @VisibleForTesting + CompletableFuture incrementDeliveryAttemptCount(final UUID accountIdentifier, final byte deviceId, final UUID messageGuid) { + final String firstMessageGuidKey = "firstMessageGuid::{" + accountIdentifier + ":" + deviceId + "}"; + final String deliveryAttemptsKey = "firstMessageDeliveryAttempts::{" + accountIdentifier + ":" + deviceId + "}"; + + return getDeliveryAttemptsScript.executeAsync( + List.of(firstMessageGuidKey, deliveryAttemptsKey), + List.of(messageGuid.toString(), String.valueOf(DELIVERY_ATTEMPTS_COUNTER_TTL.toSeconds()))) + .thenApply(result -> (long) result); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 877142bff..9b7356d98 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; @@ -58,6 +59,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final ScheduledExecutorService scheduledExecutorService; private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; + private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final Map openAuthenticatedWebsocketsByClientPlatform; private final Map openUnauthenticatedWebsocketsByClientPlatform; @@ -77,7 +79,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { ClientPresenceManager clientPresenceManager, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager) { + ClientReleaseManager clientReleaseManager, + MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; @@ -87,6 +90,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; + this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; openAuthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); openUnauthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); @@ -151,7 +155,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { context.getClient(), scheduledExecutorService, messageDeliveryScheduler, - clientReleaseManager); + clientReleaseManager, + messageDeliveryLoopMonitor); openWebsocketAtomicInteger.incrementAndGet(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 83ed82e6d..2334f24d9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -39,7 +39,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -117,6 +119,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final MessageMetrics messageMetrics; private final PushNotificationManager pushNotificationManager; private final PushNotificationScheduler pushNotificationScheduler; + private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final AuthenticatedDevice auth; private final WebSocketClient client; @@ -155,7 +158,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac WebSocketClient client, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager) { + ClientReleaseManager clientReleaseManager, + MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { this(receiptSender, messagesManager, @@ -167,7 +171,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, scheduledExecutorService, messageDeliveryScheduler, - clientReleaseManager); + clientReleaseManager, + messageDeliveryLoopMonitor); } @VisibleForTesting @@ -181,7 +186,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac int sendFuturesTimeoutMillis, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager) { + ClientReleaseManager clientReleaseManager, + MessageDeliveryLoopMonitor messageDeliveryLoopMonitor) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; @@ -194,6 +200,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; + this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; } public void start() { @@ -378,12 +385,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final Publisher messages = messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly); + final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); final AtomicBoolean hasErrored = new AtomicBoolean(); final Disposable subscription = Flux.from(messages) .name(SEND_MESSAGES_FLUX_NAME) .tap(Micrometer.metrics(Metrics.globalRegistry)) .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) + .doOnNext(envelope -> { + if (hasSentFirstMessage.compareAndSet(false, true)) { + messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI), + auth.getAuthenticatedDevice().getId(), + UUID.fromString(envelope.getServerGuid()), + client.getUserAgent(), + "websocket"); + } + }) .flatMapSequential(envelope -> Mono.fromFuture(() -> sendMessage(envelope) .orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)) diff --git a/service/src/main/resources/lua/get_delivery_attempt_count.lua b/service/src/main/resources/lua/get_delivery_attempt_count.lua new file mode 100644 index 000000000..beba67e4f --- /dev/null +++ b/service/src/main/resources/lua/get_delivery_attempt_count.lua @@ -0,0 +1,13 @@ +local firstMessageGuidKey = KEYS[1] +local firstMessageAttemptsKey = KEYS[2] + +local firstMessageGuid = ARGV[1] +local ttlSeconds = ARGV[2] + +if firstMessageGuid ~= redis.call("GET", firstMessageGuidKey) then + -- This is the first time we've attempted to deliver this message as the first message in a "page" + redis.call("SET", firstMessageGuidKey, firstMessageGuid, "EX", ttlSeconds) + redis.call("SET", firstMessageAttemptsKey, 0, "EX", ttlSeconds) +end + +return redis.call("INCR", firstMessageAttemptsKey) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 3339a7fd3..e4f0ce911 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -103,6 +103,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; @@ -205,7 +206,8 @@ class MessageControllerTest { new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, messagesManager, pushNotificationManager, pushNotificationScheduler, reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager, - serverSecretParams, SpamChecker.noop(), new MessageMetrics(), clock)) + serverSecretParams, SpamChecker.noop(), new MessageMetrics(), mock(MessageDeliveryLoopMonitor.class), + clock)) .build(); @BeforeEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/MessageDeliveryLoopMonitorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/MessageDeliveryLoopMonitorTest.java new file mode 100644 index 000000000..a4c09fb01 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/MessageDeliveryLoopMonitorTest.java @@ -0,0 +1,38 @@ +package org.whispersystems.textsecuregcm.limits; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.storage.Device; + +class MessageDeliveryLoopMonitorTest { + + private MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @BeforeEach + void setUp() { + messageDeliveryLoopMonitor = new MessageDeliveryLoopMonitor(REDIS_CLUSTER_EXTENSION.getRedisCluster()); + } + + @Test + void incrementDeliveryAttemptCount() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + assertEquals(1, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, UUID.randomUUID()).join()); + assertEquals(1, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, UUID.randomUUID()).join()); + + final UUID repeatedDeliveryGuid = UUID.randomUUID(); + + for (int i = 1; i < 10; i++) { + assertEquals(i, messageDeliveryLoopMonitor.incrementDeliveryAttemptCount(accountIdentifier, deviceId, repeatedDeliveryGuid).join()); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 53622c21e..e9d0863d6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -46,6 +46,7 @@ import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; @@ -133,7 +134,8 @@ class WebSocketConnectionIntegrationTest { webSocketClient, scheduledExecutorService, messageDeliveryScheduler, - clientReleaseManager); + clientReleaseManager, + mock(MessageDeliveryLoopMonitor.class)); final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); @@ -220,7 +222,8 @@ class WebSocketConnectionIntegrationTest { webSocketClient, scheduledExecutorService, messageDeliveryScheduler, - clientReleaseManager); + clientReleaseManager, + mock(MessageDeliveryLoopMonitor.class)); final int persistedMessageCount = 207; final int cachedMessageCount = 173; @@ -289,7 +292,8 @@ class WebSocketConnectionIntegrationTest { 100, // use a very short timeout, so that this test completes quickly scheduledExecutorService, messageDeliveryScheduler, - clientReleaseManager); + clientReleaseManager, + mock(MessageDeliveryLoopMonitor.class)); final int persistedMessageCount = 207; final int cachedMessageCount = 173; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 0a114c051..9c4346686 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -56,6 +56,7 @@ import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -124,7 +125,8 @@ class WebSocketConnectionTest { new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); + mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager, + mock(MessageDeliveryLoopMonitor.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) @@ -628,7 +630,7 @@ class WebSocketConnectionTest { private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, - retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager); + retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); } @Test