diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 30d56d9f4..2b7d60a59 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -210,7 +210,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac client.close(1000, "OK"); } - private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { + private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { // clear ephemeral field from the envelope final Optional body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray()); @@ -227,11 +227,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac sendFailuresMeter.mark(); } }).thenCompose(response -> { - final CompletableFuture result; + final CompletableFuture result; if (isSuccessResponse(response)) { result = messagesManager.delete(auth.getAccount().getUuid(), device.getId(), - storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()); + storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) + .thenApply(ignored -> null); if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { recordMessageDeliveryDuration(message.getTimestamp(), device); @@ -364,31 +365,37 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) .flatMapSequential(envelope -> Mono.fromFuture(sendMessage(envelope) - .orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)) - .doOnError(e -> { - final String errorType; - if (e instanceof TimeoutException) { - errorType = "timeout"; - } else if (e instanceof java.nio.channels.ClosedChannelException) { - errorType = "closedChannel"; - } else { - logger.warn("Send message failed", e); - errorType = "other"; - } - final Tags tags = Tags.of( - UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - Tag.of(ERROR_TYPE_TAG, errorType)); - Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, tags).increment(); - })) - .doOnError(queueCleared::completeExceptionally) - .doOnComplete(() -> queueCleared.complete(null)) + .orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS))) .subscribeOn(reactiveScheduler) - .subscribe(); + .subscribe( + // no additional consumer of values - it is Flux by now + null, + // the first error will terminate the stream, but we may get multiple errors from in-flight messages + e -> { + queueCleared.completeExceptionally(e); + + final String errorType; + if (e instanceof TimeoutException) { + errorType = "timeout"; + } else if (e instanceof java.nio.channels.ClosedChannelException) { + errorType = "closedChannel"; + } else { + logger.warn("Send message failed", e); + errorType = "other"; + } + final Tags tags = Tags.of( + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + Tag.of(ERROR_TYPE_TAG, errorType)); + Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, tags).increment(); + }, + // completion + () -> queueCleared.complete(null) + ); messageSubscription.set(subscription); } - private CompletableFuture sendMessage(Envelope envelope) { + private CompletableFuture sendMessage(Envelope envelope) { final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); if (envelope.getStory() && !client.shouldDeliverStories()) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index d7a9b793f..982d98aa8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -87,6 +87,7 @@ class WebSocketConnectionTest { private Device device; private AuthenticatedAccount auth; private UpgradeRequest upgradeRequest; + private MessagesManager messagesManager; private ReceiptSender receiptSender; private ScheduledExecutorService retrySchedulingExecutor; @@ -98,6 +99,7 @@ class WebSocketConnectionTest { device = mock(Device.class); auth = new AuthenticatedAccount(() -> new Pair<>(account, device)); upgradeRequest = mock(UpgradeRequest.class); + messagesManager = mock(MessagesManager.class); receiptSender = mock(ReceiptSender.class); retrySchedulingExecutor = mock(ScheduledExecutorService.class); } @@ -109,9 +111,8 @@ class WebSocketConnectionTest { @Test void testCredentials() { - MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, mock(PushNotificationManager.class), mock(ClientPresenceManager.class), retrySchedulingExecutor); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); @@ -145,7 +146,6 @@ class WebSocketConnectionTest { @Test void testOpen() { - MessagesManager storedMessages = mock(MessagesManager.class); UUID accountUuid = UUID.randomUUID(); UUID senderOneUuid = UUID.randomUUID(); @@ -171,9 +171,12 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); + when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + String userAgent = HttpHeaders.USER_AGENT; - when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) .thenReturn(Flux.fromIterable(outgoingMessages)); final List> futures = new LinkedList<>(); @@ -187,7 +190,7 @@ class WebSocketConnectionTest { return future; }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); connection.start(); @@ -203,7 +206,7 @@ class WebSocketConnectionTest { futures.get(0).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException()); - verify(storedMessages, times(1)).delete(eq(accountUuid), eq(deviceId), + verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); verify(receiptSender, times(1)).sendReceipt(eq(accountUuid), eq(deviceId), eq(senderOneUuid), eq(2222L)); @@ -214,7 +217,6 @@ class WebSocketConnectionTest { @Test public void testOnlineSend() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -277,8 +279,6 @@ class WebSocketConnectionTest { @Test void testPendingSend() { - MessagesManager storedMessages = mock(MessagesManager.class); - final UUID accountUuid = UUID.randomUUID(); final UUID senderTwoUuid = UUID.randomUUID(); @@ -319,9 +319,12 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); + when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + String userAgent = HttpHeaders.USER_AGENT; - when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) .thenReturn(Flux.fromIterable(pendingMessages)); final List> futures = new LinkedList<>(); @@ -335,7 +338,7 @@ class WebSocketConnectionTest { return future; }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); connection.start(); @@ -358,7 +361,6 @@ class WebSocketConnectionTest { @Test void testProcessStoredMessageConcurrency() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -424,7 +426,6 @@ class WebSocketConnectionTest { @Test void testProcessStoredMessagesMultiplePages() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -477,7 +478,6 @@ class WebSocketConnectionTest { @Test void testProcessStoredMessagesContainsSenderUuid() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -540,7 +540,6 @@ class WebSocketConnectionTest { @Test void testProcessStoredMessagesSingleEmptyCall() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -570,7 +569,6 @@ class WebSocketConnectionTest { @Test public void testRequeryOnStateMismatch() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -627,7 +625,6 @@ class WebSocketConnectionTest { @Test void testProcessCachedMessagesOnly() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -660,7 +657,6 @@ class WebSocketConnectionTest { @Test void testProcessDatabaseMessagesAfterPersist() { - final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); @@ -690,8 +686,6 @@ class WebSocketConnectionTest { @Test void testRetrieveMessageException() { - MessagesManager storedMessages = mock(MessagesManager.class); - UUID accountUuid = UUID.randomUUID(); when(device.getId()).thenReturn(2L); @@ -699,7 +693,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) .thenReturn(Flux.error(new RedisException("OH NO"))); when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer( @@ -711,7 +705,7 @@ class WebSocketConnectionTest { final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); connection.start(); @@ -722,8 +716,6 @@ class WebSocketConnectionTest { @Test void testRetrieveMessageExceptionClientDisconnected() { - MessagesManager storedMessages = mock(MessagesManager.class); - UUID accountUuid = UUID.randomUUID(); when(device.getId()).thenReturn(2L); @@ -731,13 +723,13 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) .thenReturn(Flux.error(new RedisException("OH NO"))); final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(false); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); connection.start(); @@ -748,8 +740,6 @@ class WebSocketConnectionTest { @Test @Disabled("This test is flaky") void testReactivePublisherLimitRate() { - MessagesManager storedMessages = mock(MessagesManager.class); - final UUID accountUuid = UUID.randomUUID(); final long deviceId = 2L; @@ -771,7 +761,7 @@ class WebSocketConnectionTest { }); }); - when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); @@ -779,10 +769,10 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor); connection.start(); @@ -808,8 +798,6 @@ class WebSocketConnectionTest { @Test void testReactivePublisherDisposedWhenConnectionStopped() { - MessagesManager storedMessages = mock(MessagesManager.class); - final UUID accountUuid = UUID.randomUUID(); final long deviceId = 2L; @@ -830,7 +818,7 @@ class WebSocketConnectionTest { s.onCancel(() -> canceled.set(true)); }); - when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); @@ -838,10 +826,10 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, retrySchedulingExecutor, Schedulers.immediate()); connection.start();