From 6f9ff3be3755032b4a0030b3eb39164f28cb4f83 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 9 Sep 2020 18:04:30 -0400 Subject: [PATCH] Avoid querying the database if we think all new messages are in the cache. --- .../controllers/MessageController.java | 2 +- .../storage/MessagesManager.java | 5 +- .../websocket/WebSocketConnection.java | 26 +++++- .../controllers/MessageControllerTest.java | 4 +- .../websocket/WebSocketConnectionTest.java | 92 +++++++++++++++---- 5 files changed, 103 insertions(+), 26 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index c66d4a7e5..547bf3934 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -193,7 +193,7 @@ public class MessageController { final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), account.getAuthenticatedDevice().get().getId(), - userAgent); + userAgent, false); outgoingMessageListSizeHistogram.update(outgoingMessages.getMessages().size()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 1e8065046..aa64ead60 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.util.Constants; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -49,10 +50,10 @@ public class MessagesManager { return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice); } - public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent) { + public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); - List messages = this.messages.load(destination, destinationDevice); + List messages = cachedMessagesOnly ? new ArrayList<>() : this.messages.load(destination, destinationDevice); if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { messages.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b67a24f31..ceee674e1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -65,7 +65,9 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; - private int storedMessageState = 0; + private int storedMessageState = 1; + private int lastPersistedState = 1; + private int lastDatabaseClearedState = 0; private boolean processingStoredMessages = false; private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); @@ -190,7 +192,8 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability @VisibleForTesting void processStoredMessages() { - final int processedState; + final int processedState; + final boolean cachedMessagesOnly; synchronized (this) { if (processingStoredMessages) { @@ -199,9 +202,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processingStoredMessages = true; processedState = storedMessageState; + cachedMessagesOnly = lastPersistedState <= lastDatabaseClearedState; } - OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); + OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; for (int i = 0; i < messages.getMessages().size(); i++) { @@ -232,14 +236,19 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability } CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { + final boolean mayHaveMoreMessages; + synchronized (this) { processingStoredMessages = false; + mayHaveMoreMessages = messages.hasMore() || storedMessageState > processedState; } - if (messages.hasMore() || storedMessageState > processedState) { + if (mayHaveMoreMessages) { processStoredMessages(); } else { - final boolean shouldSendEmptyQueueMessage; + synchronized (this) { + lastDatabaseClearedState = processedState; + } if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); @@ -267,6 +276,13 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); + + synchronized (this) { + storedMessageState++; + lastPersistedState = storedMessageState; + } + + processStoredMessages(); } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index 757c6a104..dc8738068 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -257,7 +257,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -294,7 +294,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); Response response = resources.getJerseyTest().target("/v1/messages/") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index f06f3dd62..ad308ea48 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -39,16 +39,17 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -149,7 +150,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -236,7 +237,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -347,7 +348,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -401,7 +402,7 @@ public class WebSocketConnectionTest { final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer((Answer)invocation -> { + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); threadWaiting.notifyAll(); @@ -445,7 +446,7 @@ public class WebSocketConnectionTest { thread.join(); } - verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString()); + verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false)); } @Test @@ -469,7 +470,7 @@ public class WebSocketConnectionTest { final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) .thenReturn(firstPage) .thenReturn(secondPage); @@ -497,17 +498,15 @@ public class WebSocketConnectionTest { final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + final UUID accountUuid = UUID.randomUUID(); + when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer(new Answer() { - @Override - public OutgoingMessageEntityList answer(final InvocationOnMock invocation) throws Throwable { - return new OutgoingMessageEntityList(Collections.emptyList(), false); - } - }); + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -523,7 +522,7 @@ public class WebSocketConnectionTest { } @Test - public void testRequeryAfterOnStateMismatch() throws InterruptedException { + public void testRequeryOnStateMismatch() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); @@ -543,7 +542,7 @@ public class WebSocketConnectionTest { final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) .thenReturn(firstPage) .thenReturn(secondPage) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); @@ -573,6 +572,67 @@ public class WebSocketConnectionTest { verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } + @Test + public void testProcessCachedMessagesOnly() { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + + verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + + connection.processStoredMessages(); + + verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true); + } + + @Test + public void testProcessDatabaseMessagesAfterPersist() { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + connection.handleMessagesPersisted(); + + verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + } + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);