Avoid querying the database if we think all new messages are in the cache.

This commit is contained in:
Jon Chambers 2020-09-09 18:04:30 -04:00 committed by Jon Chambers
parent f766c57743
commit 6f9ff3be37
5 changed files with 103 additions and 26 deletions

View File

@ -193,7 +193,7 @@ public class MessageController {
final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(account.getNumber(), final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(account.getNumber(),
account.getUuid(), account.getUuid(),
account.getAuthenticatedDevice().get().getId(), account.getAuthenticatedDevice().get().getId(),
userAgent); userAgent, false);
outgoingMessageListSizeHistogram.update(outgoingMessages.getMessages().size()); outgoingMessageListSizeHistogram.update(outgoingMessages.getMessages().size());

View File

@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -49,10 +50,10 @@ public class MessagesManager {
return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice); return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice);
} }
public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent) { public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));
List<OutgoingMessageEntity> messages = this.messages.load(destination, destinationDevice); List<OutgoingMessageEntity> messages = cachedMessagesOnly ? new ArrayList<>() : this.messages.load(destination, destinationDevice);
if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) {
messages.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); messages.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size()));

View File

@ -65,7 +65,9 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
private final WebSocketClient client; private final WebSocketClient client;
private final String connectionId; private final String connectionId;
private int storedMessageState = 0; private int storedMessageState = 1;
private int lastPersistedState = 1;
private int lastDatabaseClearedState = 0;
private boolean processingStoredMessages = false; private boolean processingStoredMessages = false;
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
@ -191,6 +193,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
@VisibleForTesting @VisibleForTesting
void processStoredMessages() { void processStoredMessages() {
final int processedState; final int processedState;
final boolean cachedMessagesOnly;
synchronized (this) { synchronized (this) {
if (processingStoredMessages) { if (processingStoredMessages) {
@ -199,9 +202,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
processingStoredMessages = true; processingStoredMessages = true;
processedState = storedMessageState; processedState = storedMessageState;
cachedMessagesOnly = lastPersistedState <= lastDatabaseClearedState;
} }
OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.getMessages().size()]; CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.getMessages().size()];
for (int i = 0; i < messages.getMessages().size(); i++) { for (int i = 0; i < messages.getMessages().size(); i++) {
@ -232,14 +236,19 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
} }
CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> {
final boolean mayHaveMoreMessages;
synchronized (this) { synchronized (this) {
processingStoredMessages = false; processingStoredMessages = false;
mayHaveMoreMessages = messages.hasMore() || storedMessageState > processedState;
} }
if (messages.hasMore() || storedMessageState > processedState) { if (mayHaveMoreMessages) {
processStoredMessages(); processStoredMessages();
} else { } else {
final boolean shouldSendEmptyQueueMessage; synchronized (this) {
lastDatabaseClearedState = processedState;
}
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty());
@ -267,6 +276,13 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability
@Override @Override
public void handleMessagesPersisted() { public void handleMessagesPersisted() {
messagesPersistedMeter.mark(); messagesPersistedMeter.mark();
synchronized (this) {
storedMessageState++;
lastPersistedState = storedMessageState;
}
processStoredMessages();
} }
@Override @Override

View File

@ -257,7 +257,7 @@ public class MessageControllerTest {
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
OutgoingMessageEntityList response = OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
@ -294,7 +294,7 @@ public class MessageControllerTest {
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
Response response = Response response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")

View File

@ -39,16 +39,17 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -149,7 +150,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent"; String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList); .thenReturn(outgoingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -236,7 +237,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent"; String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList); .thenReturn(pendingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -347,7 +348,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent"; String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList); .thenReturn(pendingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -401,7 +402,7 @@ public class WebSocketConnectionTest {
final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> { when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
synchronized (threadWaiting) { synchronized (threadWaiting) {
threadWaiting.set(true); threadWaiting.set(true);
threadWaiting.notifyAll(); threadWaiting.notifyAll();
@ -445,7 +446,7 @@ public class WebSocketConnectionTest {
thread.join(); thread.join();
} }
verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString()); verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false));
} }
@Test @Test
@ -469,7 +470,7 @@ public class WebSocketConnectionTest {
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(firstPage) .thenReturn(firstPage)
.thenReturn(secondPage); .thenReturn(secondPage);
@ -497,17 +498,15 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L); when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())).thenAnswer(new Answer<OutgoingMessageEntityList>() { when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
@Override .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
public OutgoingMessageEntityList answer(final InvocationOnMock invocation) throws Throwable {
return new OutgoingMessageEntityList(Collections.emptyList(), false);
}
});
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -523,7 +522,7 @@ public class WebSocketConnectionTest {
} }
@Test @Test
public void testRequeryAfterOnStateMismatch() throws InterruptedException { public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
@ -543,7 +542,7 @@ public class WebSocketConnectionTest {
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(firstPage) .thenReturn(firstPage)
.thenReturn(secondPage) .thenReturn(secondPage)
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
@ -573,6 +572,67 @@ public class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
} }
@Test
public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
// This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to
// CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the
// whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false);
connection.processStoredMessages();
verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true);
}
@Test
public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
// This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to
// CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the
// whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for
// anything.
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false);
}
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0); null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);