Only allow one thread to process stored messages at a time.
This commit is contained in:
parent
1a0c70acc2
commit
68256d2343
|
@ -4,6 +4,7 @@ import com.codahale.metrics.Histogram;
|
|||
import com.codahale.metrics.Meter;
|
||||
import com.codahale.metrics.MetricRegistry;
|
||||
import com.codahale.metrics.SharedMetricRegistries;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -63,6 +64,8 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
|
|||
private final WebSocketClient client;
|
||||
private final String connectionId;
|
||||
|
||||
private boolean processingStoredMessages = false;
|
||||
|
||||
public WebSocketConnection(PushSender pushSender,
|
||||
ReceiptSender receiptSender,
|
||||
MessagesManager messagesManager,
|
||||
|
@ -180,7 +183,16 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
|
|||
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
|
||||
}
|
||||
|
||||
private void processStoredMessages() {
|
||||
@VisibleForTesting
|
||||
void processStoredMessages() {
|
||||
synchronized (this) {
|
||||
if (processingStoredMessages) {
|
||||
return;
|
||||
}
|
||||
|
||||
processingStoredMessages = true;
|
||||
}
|
||||
|
||||
OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent());
|
||||
Iterator<OutgoingMessageEntity> iterator = messages.getMessages().iterator();
|
||||
|
||||
|
@ -214,6 +226,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
|
|||
if (!messages.hasMore()) {
|
||||
client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty());
|
||||
}
|
||||
|
||||
synchronized (this) {
|
||||
processingStoredMessages = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package org.whispersystems.textsecuregcm.tests.websocket;
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import org.eclipse.jetty.websocket.api.UpgradeRequest;
|
||||
|
@ -31,6 +31,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
|||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
|
@ -39,6 +40,8 @@ import java.util.Optional;
|
|||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -376,6 +379,66 @@ public class WebSocketConnectionTest {
|
|||
verify(client).close(anyInt(), anyString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProcessStoredMessageConcurrency() throws InterruptedException {
|
||||
final MessagesManager messagesManager = mock(MessagesManager.class);
|
||||
final WebSocketClient client = mock(WebSocketClient.class);
|
||||
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
|
||||
|
||||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
|
||||
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
|
||||
|
||||
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
|
||||
synchronized (threadWaiting) {
|
||||
threadWaiting.set(true);
|
||||
threadWaiting.notifyAll();
|
||||
}
|
||||
|
||||
synchronized (returnMessageList) {
|
||||
while (!returnMessageList.get()) {
|
||||
returnMessageList.wait();
|
||||
}
|
||||
}
|
||||
|
||||
return new OutgoingMessageEntityList(Collections.emptyList(), false);
|
||||
});
|
||||
|
||||
final Thread[] threads = new Thread[10];
|
||||
final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1);
|
||||
|
||||
for (int i = 0; i < threads.length; i++) {
|
||||
threads[i] = new Thread(() -> {
|
||||
connection.processStoredMessages();
|
||||
unblockedThreadsLatch.countDown();
|
||||
});
|
||||
|
||||
threads[i].start();
|
||||
}
|
||||
|
||||
unblockedThreadsLatch.await();
|
||||
|
||||
synchronized (threadWaiting) {
|
||||
while (!threadWaiting.get()) {
|
||||
threadWaiting.wait();
|
||||
}
|
||||
}
|
||||
|
||||
synchronized (returnMessageList) {
|
||||
returnMessageList.set(true);
|
||||
returnMessageList.notifyAll();
|
||||
}
|
||||
|
||||
for (final Thread thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString());
|
||||
}
|
||||
|
||||
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
|
||||
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
|
Loading…
Reference in New Issue