diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 2735699cf..30720a9db 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -61,6 +61,7 @@ import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; +import javax.annotation.Nullable; public class WebSocketConnection implements ClientEventListener { @@ -314,70 +315,78 @@ public class WebSocketConnection implements ClientEventListener { void processStoredMessages() { if (processStoredMessagesSemaphore.tryAcquire()) { final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); - final CompletableFuture queueCleared = new CompletableFuture<>(); + final boolean cachedMessagesOnly = state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE; + sendMessages(cachedMessagesOnly) + // Update our state with the outcome, send the empty queue message if we need to, and release the semaphore + .whenComplete((ignored, cause) -> { + try { + if (cause != null) { + // We failed, if the state is currently EMPTY, set it to what it was before we tried + storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); + return; + } - sendMessages(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueCleared); + // Cleared the queue! Send a queue empty message if we need to + consecutiveRetries.set(0); + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); + final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); - setQueueClearedHandler(state, queueCleared); + Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); + Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); + + if (drainDuration > SLOW_DRAIN_THRESHOLD) { + Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); + } + + client.sendRequest("PUT", "/api/v1/queue/empty", + Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); + } + } finally { + processStoredMessagesSemaphore.release(); + } + }) + // Potentially kick off more work, must happen after we release the semaphore + .whenComplete((ignored, cause) -> processMoreIfRequested(cause)); } } - private void setQueueClearedHandler(final StoredMessageState state, final CompletableFuture queueCleared) { - - queueCleared.whenComplete((v, cause) -> { - if (cause == null) { - consecutiveRetries.set(0); - - if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - final List tags = List.of( - UserAgentTagUtil.getPlatformTag(client.getUserAgent()) - ); - final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); - - Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); - Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); - - if (drainDuration > SLOW_DRAIN_THRESHOLD) { - Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); - } - - client.sendRequest("PUT", "/api/v1/queue/empty", - Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); - } - } else { - storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); + /** + * After processing messages, kick off another processing job if more messages came in or if there was an error + * + * @param cause An error that was encountered when processing the message queue, if there was one + */ + private void processMoreIfRequested(final @Nullable Throwable cause) { + if (cause == null) { + // Success, but check if more messages came in while we were processing + if (storedMessageState.get() != StoredMessageState.EMPTY) { + processStoredMessages(); } + return; + } - processStoredMessagesSemaphore.release(); + if (!client.isOpen()) { + logger.debug("Client disconnected before queue cleared"); + return; + } - if (cause == null) { - if (storedMessageState.get() != StoredMessageState.EMPTY) { - processStoredMessages(); - } - } else { - if (client.isOpen()) { + if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) { + logger.warn("Max consecutive retries exceeded", cause); + client.close(1011, "Failed to retrieve messages"); + return; + } - if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) { - logger.warn("Max consecutive retries exceeded", cause); - client.close(1011, "Failed to retrieve messages"); - } else { - logger.debug("Failed to clear queue", cause); - final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); + logger.debug("Failed to clear queue", cause); + final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); - Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment(); + Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment(); - final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS); - retryFuture - .set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS)); - } - } else { - logger.debug("Client disconnected before queue cleared"); - } - } - }); + final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS); + retryFuture.set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS)); } - private void sendMessages(final boolean cachedMessagesOnly, final CompletableFuture queueCleared) { + private CompletableFuture sendMessages(final boolean cachedMessagesOnly) { + final CompletableFuture queueCleared = new CompletableFuture<>(); final Publisher messages = messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly); @@ -423,6 +432,7 @@ public class WebSocketConnection implements ClientEventListener { ); messageSubscription.set(subscription); + return queueCleared; } private void measureSendMessageErrors(Throwable e, final boolean terminal) {