From fc71ced660c2c5ae949f0ee9813c27c1ff519902 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 23 Sep 2020 12:23:34 -0400 Subject: [PATCH] Persist messages in batches. --- .../storage/MessagePersister.java | 6 +- .../textsecuregcm/storage/Messages.java | 51 ++++++----- .../storage/MessagesManager.java | 9 +- .../storage/MessagePersisterTest.java | 35 +++++--- .../tests/storage/MessagesTest.java | 89 +++++++------------ .../WebSocketConnectionIntegrationTest.java | 34 ++++--- 6 files changed, 118 insertions(+), 106 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 369967cdb..f969f40b4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -123,10 +123,8 @@ public class MessagePersister implements Managed { do { messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); - for (final MessageProtos.Envelope message : messages) { - messagesManager.persistMessage(accountNumber, accountUuid, message, UUID.fromString(message.getServerGuid()), deviceId); - messageCount++; - } + messagesManager.persistMessages(accountNumber, accountUuid, deviceId, messages); + messageCount += messages.size(); } while (!messages.isEmpty()); queueSizeHistogram.update(messageCount); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java index 04224db84..07785a948 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java @@ -1,9 +1,11 @@ package org.whispersystems.textsecuregcm.storage; +import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import org.jdbi.v3.core.statement.PreparedBatch; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.storage.mappers.OutgoingMessageEntityRowMapper; @@ -43,6 +45,7 @@ public class Messages { private final Timer clearTimer = metricRegistry.timer(name(Messages.class, "clear" )); private final Timer vacuumTimer = metricRegistry.timer(name(Messages.class, "vacuum")); private final Meter insertNullGuidMeter = metricRegistry.meter(name(Messages.class, "insertNullGuid")); + private final Histogram storeSizeHistogram = metricRegistry.histogram(name(Messages.class, "storeBatchSize")); private final FaultTolerantDatabase database; @@ -51,28 +54,34 @@ public class Messages { this.database.getDatabase().registerRowMapper(new OutgoingMessageEntityRowMapper()); } - public void store(UUID guid, Envelope message, String destination, long destinationDevice) { - if (guid == null) { - insertNullGuidMeter.mark(); - } + public void store(final List messages, final String destination, final long destinationDevice) { + database.use(jdbi -> jdbi.useTransaction(handle -> { + try (final Timer.Context ignored = storeTimer.time()) { + final PreparedBatch batch = handle.prepareBatch("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + + "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)"); - database.use(jdbi ->jdbi.useHandle(handle -> { - try (Timer.Context ignored = storeTimer.time()) { - handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + - "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)") - .bind("guid", guid) - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .bind("type", message.getType().getNumber()) - .bind("relay", message.getRelay()) - .bind("timestamp", message.getTimestamp()) - .bind("server_timestamp", message.getServerTimestamp()) - .bind("source", message.hasSource() ? message.getSource() : null) - .bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null) - .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) - .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) - .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) - .execute(); + for (final Envelope message : messages) { + if (message.getServerGuid() == null) { + insertNullGuidMeter.mark(); + } + + batch.bind("guid", UUID.fromString(message.getServerGuid())) + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .bind("type", message.getType().getNumber()) + .bind("relay", message.getRelay()) + .bind("timestamp", message.getTimestamp()) + .bind("server_timestamp", message.getServerTimestamp()) + .bind("source", message.hasSource() ? message.getSource() : null) + .bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null) + .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) + .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) + .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) + .add(); + } + + batch.execute(); + storeSizeHistogram.update(messages.size()); } })); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index f808ffc86..dbbf45af4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -117,9 +117,12 @@ public class MessagesManager { } } - public void persistMessage(String destination, UUID destinationUuid, Envelope envelope, UUID messageGuid, long deviceId) { - messages.store(messageGuid, envelope, destination, deviceId); - messagesCache.remove(destinationUuid, deviceId, messageGuid); + public void persistMessages(final String destination, final UUID destinationUuid, final long destinationDeviceId, final List messages) { + this.messages.store(messages, destination, destinationDeviceId); + + for (final Envelope message : messages) { + messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); + } } public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, final MessageAvailabilityListener listener) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index cf727e198..9ba8a7384 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -5,12 +5,14 @@ import io.lettuce.core.cluster.SlotHash; import org.apache.commons.lang3.RandomStringUtils; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ExecutorService; @@ -18,14 +20,15 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -67,17 +70,19 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, scheduledExecutorService, PERSIST_DELAY); doAnswer(invocation -> { - final String destination = invocation.getArgument(0, String.class); - final UUID destinationUuid = invocation.getArgument(1, UUID.class); - final MessageProtos.Envelope message = invocation.getArgument(2, MessageProtos.Envelope.class); - final UUID messageGuid = invocation.getArgument(3, UUID.class); - final long deviceId = invocation.getArgument(4, Long.class); + final String destination = invocation.getArgument(0, String.class); + final UUID destinationUuid = invocation.getArgument(1, UUID.class); + final long deviceId = invocation.getArgument(2, Long.class); + final List messages = invocation.getArgument(3, List.class); - messagesDatabase.store(messageGuid, message, destination, deviceId); - messagesCache.remove(destinationUuid, deviceId, messageGuid); + messagesDatabase.store(messages, destination, deviceId); + + for (final MessageProtos.Envelope message : messages) { + messagesCache.remove(destinationUuid, deviceId, UUID.fromString(message.getServerGuid())); + } return null; - }).when(messagesManager).persistMessage(anyString(), any(UUID.class), any(MessageProtos.Envelope.class), any(UUID.class), anyLong()); + }).when(messagesManager).persistMessages(anyString(), any(UUID.class), anyLong(), any()); } @Override @@ -109,7 +114,10 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); - verify(messagesDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); + + verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @Test @@ -123,7 +131,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister.persistNextQueues(now); - verify(messagesDatabase, never()).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong()); + verify(messagesDatabase, never()).store(any(), anyString(), anyLong()); } @Test @@ -151,7 +159,10 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); - verify(messagesDatabase, times(queueCount * messagesPerQueue)).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong()); + final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); + + verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), anyString(), anyLong()); + assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @SuppressWarnings("SameParameterValue") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java index 477d876dd..11501f567 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java @@ -44,9 +44,8 @@ public class MessagesTest { @Test public void testStore() throws SQLException { Envelope envelope = generateEnvelope(); - UUID guid = UUID.randomUUID(); - messages.store(guid, envelope, "+14151112222", 1); + messages.store(List.of(envelope), "+14151112222", 1); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?"); statement.setString(1, "+14151112222"); @@ -54,7 +53,7 @@ public class MessagesTest { ResultSet resultSet = statement.executeQuery(); assertThat(resultSet.next()).isTrue(); - assertThat(resultSet.getString("guid")).isEqualTo(guid.toString()); + assertThat(resultSet.getString("guid")).isEqualTo(envelope.getServerGuid()); assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber()); assertThat(resultSet.getString("relay")).isNullOrEmpty(); assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp()); @@ -71,36 +70,29 @@ public class MessagesTest { @Test public void testLoad() { - List inserted = new ArrayList<>(50); + List inserted = insertRandom("+14151112222", 1); - for (int i=0;i<50;i++) { - MessageToStore message = generateMessageToStore(); - inserted.add(message); - - messages.store(message.guid, message.envelope, "+14151112222", 1); - } - - inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); List retrieved = messages.load("+14151112222", 1); assertThat(retrieved.size()).isEqualTo(inserted.size()); for (int i=0;i inserted = insertRandom("+14151112222", 1); - List unrelated = insertRandom("+14151114444", 3); - MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", 1, toRemove.envelope.getSource(), toRemove.envelope.getTimestamp()); + List inserted = insertRandom("+14151112222", 1); + List unrelated = insertRandom("+14151114444", 3); + Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", 1, toRemove.getSource(), toRemove.getTimestamp()); assertThat(removed.isPresent()).isTrue(); - verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); + verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid())); verifyInTact(inserted, "+14151112222", 1); verifyInTact(unrelated, "+14151114444", 3); @@ -108,13 +100,13 @@ public class MessagesTest { @Test public void removeByDestinationGuid() { - List unrelated = insertRandom("+14151113333", 2); - List inserted = insertRandom("+14151112222", 1); - MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", toRemove.guid); + List unrelated = insertRandom("+14151113333", 2); + List inserted = insertRandom("+14151112222", 1); + Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", UUID.fromString(toRemove.getServerGuid())); assertThat(removed.isPresent()).isTrue(); - verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); + verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid())); verifyInTact(inserted, "+14151112222", 1); verifyInTact(unrelated, "+14151113333", 2); @@ -122,10 +114,10 @@ public class MessagesTest { @Test public void removeByDestinationRowId() { - List unrelatedInserted = insertRandom("+14151111111", 1); - List inserted = insertRandom("+14151112222", 1); + List unrelatedInserted = insertRandom("+14151111111", 1); + List inserted = insertRandom("+14151112222", 1); - inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); List retrieved = messages.load("+14151112222", 1); @@ -141,9 +133,8 @@ public class MessagesTest { @Test public void testLoadEmpty() { - List inserted = insertRandom("+14151112222", 1); - List loaded = messages.load("+14159999999", 1); - assertThat(loaded.isEmpty()).isTrue(); + insertRandom("+14151112222", 1); + assertThat(messages.load("+14159999999", 1).isEmpty()).isTrue(); } @Test @@ -151,7 +142,7 @@ public class MessagesTest { insertRandom("+14151112222", 1); insertRandom("+14151112222", 2); - List unrelated = insertRandom("+14151111111", 1); + List unrelated = insertRandom("+14151111111", 1); messages.clear("+14151112222"); @@ -163,9 +154,9 @@ public class MessagesTest { @Test public void testClearDestinationDevice() { insertRandom("+14151112222", 1); - List inserted = insertRandom("+14151112222", 2); + List inserted = insertRandom("+14151112222", 2); - List unrelated = insertRandom("+14151111111", 1); + List unrelated = insertRandom("+14151111111", 1); messages.clear("+14151112222", 1); @@ -177,33 +168,32 @@ public class MessagesTest { @Test public void testVacuum() { - List inserted = insertRandom("+14151112222", 2); + List inserted = insertRandom("+14151112222", 2); messages.vacuum(); verifyInTact(inserted, "+14151112222", 2); } - private List insertRandom(String destination, int destinationDevice) { - List inserted = new ArrayList<>(50); + private List insertRandom(String destination, int destinationDevice) { + List inserted = new ArrayList<>(50); for (int i=0;i<50;i++) { - MessageToStore message = generateMessageToStore(); - inserted.add(message); - - messages.store(message.guid, message.envelope, destination, destinationDevice); + inserted.add(generateEnvelope()); } + messages.store(inserted, destination, destinationDevice); + return inserted; } - private void verifyInTact(List inserted, String destination, int destinationDevice) { - inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + private void verifyInTact(List inserted, String destination, int destinationDevice) { + inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); List retrieved = messages.load(destination, destinationDevice); assertThat(retrieved.size()).isEqualTo(inserted.size()); for (int i=0;i expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); - for (int i = 0; i < persistedMessageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); - messages.store(messageGuid, envelope, account.getNumber(), device.getId()); - expectedMessages.add(envelope.toBuilder().clearServerGuid().build()); + for (int i = 0; i < persistedMessageCount; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + + persistedMessages.add(envelope); + expectedMessages.add(envelope.toBuilder().clearServerGuid().build()); + } + + messages.store(persistedMessages, account.getNumber(), device.getId()); } for (int i = 0; i < cachedMessageCount; i++) { @@ -172,9 +179,14 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest final int persistedMessageCount = 207; final int cachedMessageCount = 173; - for (int i = 0; i < persistedMessageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - messages.store(messageGuid, generateRandomMessage(messageGuid), account.getNumber(), device.getId()); + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); + + for (int i = 0; i < persistedMessageCount; i++) { + persistedMessages.add(generateRandomMessage(UUID.randomUUID())); + } + + messages.store(persistedMessages, account.getNumber(), device.getId()); } for (int i = 0; i < cachedMessageCount; i++) { @@ -191,9 +203,11 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest } private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { + final long timestamp = serialTimestamp++; + return MessageProtos.Envelope.newBuilder() - .setTimestamp(System.currentTimeMillis()) - .setServerTimestamp(System.currentTimeMillis()) + .setTimestamp(timestamp) + .setServerTimestamp(timestamp) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setServerGuid(messageGuid.toString())