diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 571b6ab7e..b54abbe01 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -4,6 +4,7 @@ import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; @@ -63,6 +64,8 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; + private boolean processingStoredMessages = false; + public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, @@ -180,7 +183,16 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } - private void processStoredMessages() { + @VisibleForTesting + void processStoredMessages() { + synchronized (this) { + if (processingStoredMessages) { + return; + } + + processingStoredMessages = true; + } + OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); Iterator iterator = messages.getMessages().iterator(); @@ -214,6 +226,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability if (!messages.hasMore()) { client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); } + + synchronized (this) { + processingStoredMessages = false; + } } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java similarity index 89% rename from service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 4e115b7a7..0cbb3ad14 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.tests.websocket; +package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.ByteString; import org.eclipse.jetty.websocket.api.UpgradeRequest; @@ -31,6 +31,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; @@ -39,6 +40,8 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; import io.dropwizard.auth.basic.BasicCredentials; import static org.junit.Assert.*; @@ -376,6 +379,66 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } + @Test + public void testProcessStoredMessageConcurrency() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final AtomicBoolean threadWaiting = new AtomicBoolean(false); + final AtomicBoolean returnMessageList = new AtomicBoolean(false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer((Answer)invocation -> { + synchronized (threadWaiting) { + threadWaiting.set(true); + threadWaiting.notifyAll(); + } + + synchronized (returnMessageList) { + while (!returnMessageList.get()) { + returnMessageList.wait(); + } + } + + return new OutgoingMessageEntityList(Collections.emptyList(), false); + }); + + final Thread[] threads = new Thread[10]; + final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); + + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + connection.processStoredMessages(); + unblockedThreadsLatch.countDown(); + }); + + threads[i].start(); + } + + unblockedThreadsLatch.await(); + + synchronized (threadWaiting) { + while (!threadWaiting.get()) { + threadWaiting.wait(); + } + } + + synchronized (returnMessageList) { + returnMessageList.set(true); + returnMessageList.notifyAll(); + } + + for (final Thread thread : threads) { + thread.join(); + } + + verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString()); + } private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,