diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 2b0906267..d78fdcdba 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import static com.codahale.metrics.MetricRegistry.name; @@ -65,11 +66,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; - private int storedMessageState = 1; - private int lastPersistedState = 1; - private int lastDatabaseClearedState = 0; - private boolean processingStoredMessages = false; - private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); + private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); + private final AtomicBoolean newMessagesAvailable = new AtomicBoolean(true); + private final AtomicBoolean persistedMessagesAvailable = new AtomicBoolean(true); + private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, @@ -96,11 +96,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: pubSubPersistedMeter.mark(); - - synchronized (this) { - storedMessageState++; - } - + newMessagesAvailable.set(true); processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: @@ -192,42 +188,16 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability @VisibleForTesting void processStoredMessages() { - final int processedState; - final boolean cachedMessagesOnly; - - synchronized (this) { - if (processingStoredMessages) { - return; - } - - processingStoredMessages = true; - processedState = storedMessageState; - cachedMessagesOnly = lastPersistedState <= lastDatabaseClearedState; - } - - sendNextMessagePage(cachedMessagesOnly).thenAccept(hasMoreStoredMessages -> { - final boolean mayHaveMoreMessages; - - synchronized (this) { - processingStoredMessages = false; - mayHaveMoreMessages = hasMoreStoredMessages || storedMessageState > processedState; - } - - if (mayHaveMoreMessages) { - processStoredMessages(); + if (processStoredMessagesSemaphore.tryAcquire()) { + if (newMessagesAvailable.getAndSet(false)) { + sendNextMessagePage(!persistedMessagesAvailable.getAndSet(false)); } else { - synchronized (this) { - lastDatabaseClearedState = processedState; - } - - if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); - } + processStoredMessagesSemaphore.release(); } - }); + } } - private CompletableFuture sendNextMessagePage(final boolean cachedMessagesOnly) { + private void sendNextMessagePage(final boolean cachedMessagesOnly) { final OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); final CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; @@ -258,34 +228,42 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); } - return CompletableFuture.allOf(sendFutures).handle((v, cause) -> messages.hasMore()); + CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { + if (messages.hasMore()) { + sendNextMessagePage(cachedMessagesOnly); + } else { + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } + + processStoredMessagesSemaphore.release(); + processStoredMessages(); + } + }); } @Override public void handleNewMessagesAvailable() { messageAvailableMeter.mark(); + + newMessagesAvailable.set(true); + processStoredMessages(); } @Override public void handleNewEphemeralMessageAvailable() { ephemeralMessageAvailableMeter.mark(); - final Optional maybeMessage = messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()); - - if (maybeMessage.isPresent()) { - sendMessage(maybeMessage.get(), Optional.empty()); - } + messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()) + .ifPresent(message -> sendMessage(message, Optional.empty())); } @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); - synchronized (this) { - storedMessageState++; - lastPersistedState = storedMessageState; - } - + persistedMessagesAvailable.set(true); + newMessagesAvailable.set(true); processStoredMessages(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index ad308ea48..57827bd71 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -49,7 +49,6 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -526,9 +525,10 @@ public class WebSocketConnectionTest { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); @@ -542,7 +542,7 @@ public class WebSocketConnectionTest { final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(firstPage) .thenReturn(secondPage) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); @@ -599,7 +599,7 @@ public class WebSocketConnectionTest { verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); - connection.processStoredMessages(); + connection.handleNewMessagesAvailable(); verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true); }