diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b54abbe01..c9c73a1ac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -95,7 +96,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability break; case PubSubMessage.Type.DELIVER_VALUE: pubSubNewMessageMeter.mark(); - sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty(), false); + sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty()); break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); @@ -114,10 +115,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processStoredMessages(); } - private void sendMessage(final Envelope message, - final Optional storedMessageInfo, - final boolean requery) - { + private CompletableFuture sendMessage(final Envelope message, final Optional storedMessageInfo) { try { String header; Optional body; @@ -132,7 +130,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability sendMessageMeter.mark(); - client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) + return client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) .thenAccept(response -> { boolean isReceipt = message.getType() == Envelope.Type.RECEIPT; @@ -143,7 +141,6 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability if (isSuccessResponse(response)) { if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); if (!isReceipt) sendDeliveryReceiptFor(message); - if (requery) processStoredMessages(); } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) { requeueMessage(message); } @@ -154,6 +151,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability }); } catch (CryptoEncodingException e) { logger.warn("Bad signaling key", e); + return CompletableFuture.failedFuture(e); } } @@ -193,11 +191,11 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processingStoredMessages = true; } - OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); - Iterator iterator = messages.getMessages().iterator(); + OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); + CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; - while (iterator.hasNext()) { - OutgoingMessageEntity message = iterator.next(); + for (int i = 0; i < messages.getMessages().size(); i++) { + OutgoingMessageEntity message = messages.getMessages().get(i); Envelope.Builder builder = Envelope.newBuilder() .setType(Envelope.Type.valueOf(message.getType())) .setTimestamp(message.getTimestamp()) @@ -220,16 +218,20 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability builder.setRelay(message.getRelay()); } - sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore()); + sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); } - if (!messages.hasMore()) { - client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); - } + CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { + synchronized (this) { + processingStoredMessages = false; + } - synchronized (this) { - processingStoredMessages = false; - } + if (messages.hasMore()) { + processStoredMessages(); + } else { + client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } + }); } @Override @@ -244,7 +246,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability final Optional maybeMessage = messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()); if (maybeMessage.isPresent()) { - sendMessage(maybeMessage.get(), Optional.empty(), false); + sendMessage(maybeMessage.get(), Optional.empty()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 0cbb3ad14..6b02922d8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,6 +1,7 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.ByteString; +import io.dropwizard.auth.basic.BasicCredentials; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Test; import org.mockito.ArgumentMatchers; @@ -21,10 +22,6 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Base64; -import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; -import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; -import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; @@ -43,10 +40,20 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import io.dropwizard.auth.basic.BasicCredentials; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; public class WebSocketConnectionTest { @@ -440,6 +447,49 @@ public class WebSocketConnectionTest { verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString()); } + @Test + public void testProcessStoredMessagesMultiplePages() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final List firstPageMessages = + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + + final List secondPageMessages = + List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); + + final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); + final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) + .thenReturn(firstPage) + .thenReturn(secondPage); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { + sendLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + connection.processStoredMessages(); + + sendLatch.await(); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);