Only allow one thread to process stored messages at a time.

This commit is contained in:
Jon Chambers 2020-09-09 11:28:59 -04:00 committed by Jon Chambers
parent 1a0c70acc2
commit 68256d2343
2 changed files with 81 additions and 2 deletions

View File

@ -4,6 +4,7 @@ import com.codahale.metrics.Histogram;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
@ -63,6 +64,8 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
private final WebSocketClient client;
private final String connectionId;
private boolean processingStoredMessages = false;
public WebSocketConnection(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
@ -180,7 +183,16 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
private void processStoredMessages() {
@VisibleForTesting
void processStoredMessages() {
synchronized (this) {
if (processingStoredMessages) {
return;
}
processingStoredMessages = true;
}
OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent());
Iterator<OutgoingMessageEntity> iterator = messages.getMessages().iterator();
@ -214,6 +226,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
if (!messages.hasMore()) {
client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty());
}
synchronized (this) {
processingStoredMessages = false;
}
}
@Override

View File

@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.tests.websocket;
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.ByteString;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
@ -31,6 +31,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
@ -39,6 +40,8 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.junit.Assert.*;
@ -376,6 +379,66 @@ public class WebSocketConnectionTest {
verify(client).close(anyInt(), anyString());
}
@Test
public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
threadWaiting.notifyAll();
}
synchronized (returnMessageList) {
while (!returnMessageList.get()) {
returnMessageList.wait();
}
}
return new OutgoingMessageEntityList(Collections.emptyList(), false);
});
final Thread[] threads = new Thread[10];
final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
connection.processStoredMessages();
unblockedThreadsLatch.countDown();
});
threads[i].start();
}
unblockedThreadsLatch.await();
synchronized (threadWaiting) {
while (!threadWaiting.get()) {
threadWaiting.wait();
}
}
synchronized (returnMessageList) {
returnMessageList.set(true);
returnMessageList.notifyAll();
}
for (final Thread thread : threads) {
thread.join();
}
verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString());
}
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,