From eecc71c77fa130e6225b36c36b7c14746631b9ff Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 20 Jul 2020 16:28:32 -0400 Subject: [PATCH] Revert batch message storage. (#95) --- .../textsecuregcm/storage/Messages.java | 44 ++++----- .../textsecuregcm/storage/MessagesCache.java | 41 ++++---- .../tests/storage/MessagesTest.java | 93 ++++++++++++------- 3 files changed, 97 insertions(+), 81 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java index 2f9e02136..a1acbd2c8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java @@ -1,11 +1,8 @@ package org.whispersystems.textsecuregcm.storage; -import com.codahale.metrics.Histogram; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; -import com.google.common.annotations.VisibleForTesting; -import org.jdbi.v3.core.statement.PreparedBatch; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.storage.mappers.OutgoingMessageEntityRowMapper; @@ -44,7 +41,6 @@ public class Messages { private final Timer clearDeviceTimer = metricRegistry.timer(name(Messages.class, "clearDevice" )); private final Timer clearTimer = metricRegistry.timer(name(Messages.class, "clear" )); private final Timer vacuumTimer = metricRegistry.timer(name(Messages.class, "vacuum")); - private final Histogram storeSizeHistogram = metricRegistry.histogram(name(Messages.class, "storeBatchSize")); private final FaultTolerantDatabase database; @@ -53,30 +49,24 @@ public class Messages { this.database.getDatabase().registerRowMapper(new OutgoingMessageEntityRowMapper()); } - public void store(List messages, String destination, long destinationDevice) { - database.use(jdbi -> jdbi.useTransaction(handle -> { + public void store(UUID guid, Envelope message, String destination, long destinationDevice) { + database.use(jdbi ->jdbi.useHandle(handle -> { try (Timer.Context ignored = storeTimer.time()) { - final PreparedBatch batch = handle.prepareBatch("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + - "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)"); - - for (final Envelope message : messages) { - batch.bind("guid", UUID.fromString(message.getServerGuid())) - .bind("type", message.getType().getNumber()) - .bind("relay", message.getRelay()) - .bind("timestamp", message.getTimestamp()) - .bind("server_timestamp", message.getServerTimestamp()) - .bind("source", message.hasSource() ? message.getSource() : null) - .bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null) - .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) - .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) - .add(); - } - - batch.execute(); - storeSizeHistogram.update(messages.size()); + handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + + "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)") + .bind("guid", guid) + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .bind("type", message.getType().getNumber()) + .bind("relay", message.getRelay()) + .bind("timestamp", message.getTimestamp()) + .bind("server_timestamp", message.getServerTimestamp()) + .bind("source", message.hasSource() ? message.getSource() : null) + .bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null) + .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) + .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) + .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) + .execute(); } })); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 908a75f5f..44d3f02a1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -19,7 +19,6 @@ import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; @@ -452,43 +451,47 @@ public class MessagesCache implements Managed { } private void persistQueue(ReplicatedJedisPool jedisPool, Key key) throws IOException { + Timer.Context timer = persistQueueTimer.time(); + int messagesPersistedCount = 0; - try (Jedis jedis = jedisPool.getWriteResource(); - Timer.Context ignored = persistQueueTimer.time()) { - + try (Jedis jedis = jedisPool.getWriteResource()) { while (true) { jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes()); Set messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE); - List envelopes = new ArrayList<>(messages.size()); - for (Tuple tuple : messages) { - try { - envelopes.add(Envelope.parseFrom(tuple.getBinaryElement())); - } catch (InvalidProtocolBufferException e) { - logger.error("Error parsing envelope", e); - } + for (Tuple message : messages) { + persistMessage(jedis, key, (long)message.getScore(), message.getBinaryElement()); + messagesPersistedCount++; } - database.store(envelopes, key.getAddress(), key.getDeviceId()); - - for (Tuple tuple : messages) { - removeOperation.remove(jedis, key.getAddress(), key.getDeviceId(), (long)tuple.getScore()); - } - - messagesPersistedCount += envelopes.size(); - if (messages.size() < CHUNK_SIZE) { jedis.del(key.getUserMessageQueuePersistInProgress()); return; } } } finally { + timer.stop(); queueSizeHistogram.update(messagesPersistedCount); } } + private void persistMessage(Jedis jedis, Key key, long score, byte[] message) { + try { + Envelope envelope = Envelope.parseFrom(message); + UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null; + + envelope = envelope.toBuilder().clearServerGuid().build(); + + database.store(guid, envelope, key.getAddress(), key.getDeviceId()); + } catch (InvalidProtocolBufferException e) { + logger.error("Error parsing envelope", e); + } + + removeOperation.remove(jedis, key.getAddress(), key.getDeviceId(), score); + } + private List getQueuesToPersist(GetOperation getOperation) { Timer.Context timer = getQueuesTimer.time(); try { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java index a9955463d..477d876dd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Optional; import java.util.Random; import java.util.UUID; -import java.util.stream.Collectors; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -45,8 +44,9 @@ public class MessagesTest { @Test public void testStore() throws SQLException { Envelope envelope = generateEnvelope(); + UUID guid = UUID.randomUUID(); - messages.store(List.of(envelope), "+14151112222", 1); + messages.store(guid, envelope, "+14151112222", 1); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?"); statement.setString(1, "+14151112222"); @@ -54,7 +54,7 @@ public class MessagesTest { ResultSet resultSet = statement.executeQuery(); assertThat(resultSet.next()).isTrue(); - assertThat(resultSet.getString("guid")).isEqualTo(envelope.getServerGuid()); + assertThat(resultSet.getString("guid")).isEqualTo(guid.toString()); assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber()); assertThat(resultSet.getString("relay")).isNullOrEmpty(); assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp()); @@ -71,28 +71,36 @@ public class MessagesTest { @Test public void testLoad() { - List inserted = insertRandom("+14151112222", 1); - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); + List inserted = new ArrayList<>(50); + + for (int i=0;i<50;i++) { + MessageToStore message = generateMessageToStore(); + inserted.add(message); + + messages.store(message.guid, message.envelope, "+14151112222", 1); + } + + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); List retrieved = messages.load("+14151112222", 1); assertThat(retrieved.size()).isEqualTo(inserted.size()); for (int i=0;i inserted = insertRandom("+14151112222", 1); - List unrelated = insertRandom("+14151114444", 3); - Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", 1, toRemove.getSource(), toRemove.getTimestamp()); + List inserted = insertRandom("+14151112222", 1); + List unrelated = insertRandom("+14151114444", 3); + MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", 1, toRemove.envelope.getSource(), toRemove.envelope.getTimestamp()); assertThat(removed.isPresent()).isTrue(); - verifyExpected(removed.get(), toRemove); + verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); verifyInTact(inserted, "+14151112222", 1); verifyInTact(unrelated, "+14151114444", 3); @@ -100,13 +108,13 @@ public class MessagesTest { @Test public void removeByDestinationGuid() { - List unrelated = insertRandom("+14151113333", 2); - List inserted = insertRandom("+14151112222", 1); - Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", UUID.fromString(toRemove.getServerGuid())); + List unrelated = insertRandom("+14151113333", 2); + List inserted = insertRandom("+14151112222", 1); + MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", toRemove.guid); - assertThat(removed).isPresent(); - verifyExpected(removed.get(), toRemove); + assertThat(removed.isPresent()).isTrue(); + verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); verifyInTact(inserted, "+14151112222", 1); verifyInTact(unrelated, "+14151113333", 2); @@ -114,10 +122,10 @@ public class MessagesTest { @Test public void removeByDestinationRowId() { - List unrelatedInserted = insertRandom("+14151111111", 1); - List inserted = insertRandom("+14151112222", 1); + List unrelatedInserted = insertRandom("+14151111111", 1); + List inserted = insertRandom("+14151112222", 1); - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); List retrieved = messages.load("+14151112222", 1); @@ -133,7 +141,7 @@ public class MessagesTest { @Test public void testLoadEmpty() { - insertRandom("+14151112222", 1); + List inserted = insertRandom("+14151112222", 1); List loaded = messages.load("+14159999999", 1); assertThat(loaded.isEmpty()).isTrue(); } @@ -143,7 +151,7 @@ public class MessagesTest { insertRandom("+14151112222", 1); insertRandom("+14151112222", 2); - List unrelated = insertRandom("+14151111111", 1); + List unrelated = insertRandom("+14151111111", 1); messages.clear("+14151112222"); @@ -155,9 +163,9 @@ public class MessagesTest { @Test public void testClearDestinationDevice() { insertRandom("+14151112222", 1); - List inserted = insertRandom("+14151112222", 2); + List inserted = insertRandom("+14151112222", 2); - List unrelated = insertRandom("+14151111111", 1); + List unrelated = insertRandom("+14151111111", 1); messages.clear("+14151112222", 1); @@ -169,37 +177,38 @@ public class MessagesTest { @Test public void testVacuum() { - List inserted = insertRandom("+14151112222", 2); + List inserted = insertRandom("+14151112222", 2); messages.vacuum(); verifyInTact(inserted, "+14151112222", 2); } - private List insertRandom(String destination, int destinationDevice) { - List inserted = new ArrayList<>(50); + private List insertRandom(String destination, int destinationDevice) { + List inserted = new ArrayList<>(50); for (int i=0;i<50;i++) { - inserted.add(generateEnvelope()); - } + MessageToStore message = generateMessageToStore(); + inserted.add(message); - messages.store(inserted, destination, destinationDevice); + messages.store(message.guid, message.envelope, destination, destinationDevice); + } return inserted; } - private void verifyInTact(List inserted, String destination, int destinationDevice) { - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); + private void verifyInTact(List inserted, String destination, int destinationDevice) { + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); List retrieved = messages.load(destination, destinationDevice); assertThat(retrieved.size()).isEqualTo(inserted.size()); for (int i=0;i