From c7230ccbb07bad71c9cadf9bc0dd0926d18e02d5 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 23 Sep 2020 14:10:51 -0400 Subject: [PATCH] Remove messages from the cache in bulk. --- .../textsecuregcm/storage/MessagesCache.java | 33 +++++++++------- .../storage/MessagesManager.java | 6 +-- .../resources/lua/remove_item_by_guid.lua | 38 ++++++++++--------- .../storage/MessagesCacheTest.java | 34 +++++++++++++++++ 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index ba0d92ea8..f08456b98 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import static com.codahale.metrics.MetricRegistry.name; @@ -94,7 +95,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE); this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE); - this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE); + this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.MULTI); this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI); @@ -163,7 +164,6 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp getQueueIndexKey(destinationUuid, destinationDevice)), List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8)))); - if (serialized != null) { return Optional.of(constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized))); } @@ -193,21 +193,28 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } public Optional remove(final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) { - try { - final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> - removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), - getMessageQueueMetadataKey(destinationUuid, destinationDevice), - getQueueIndexKey(destinationUuid, destinationDevice)), - List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8)))); + return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst(); + } - if (serialized != null) { - return Optional.of(constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized))); + @SuppressWarnings("unchecked") + public List remove(final UUID destinationUuid, final long destinationDevice, final List messageGuids) { + final List serialized = (List)Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> + removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), + messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()))); + + final List removedMessages = new ArrayList<>(serialized.size()); + + for (final byte[] bytes : serialized) { + try { + removedMessages.add(constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(bytes))); + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); } - } catch (final InvalidProtocolBufferException e) { - logger.warn("Failed to parse envelope", e); } - return Optional.empty(); + return removedMessages; } @SuppressWarnings("unchecked") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index dbbf45af4..61c941e24 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.stream.Collectors; import static com.codahale.metrics.MetricRegistry.name; @@ -119,10 +120,7 @@ public class MessagesManager { public void persistMessages(final String destination, final UUID destinationUuid, final long destinationDeviceId, final List messages) { this.messages.store(messages, destination, destinationDeviceId); - - for (final Envelope message : messages) { - messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); - } + messagesCache.remove(destinationUuid, destinationDeviceId, messages.stream().map(message -> UUID.fromString(message.getServerGuid())).collect(Collectors.toList())); } public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, final MessageAvailabilityListener listener) { diff --git a/service/src/main/resources/lua/remove_item_by_guid.lua b/service/src/main/resources/lua/remove_item_by_guid.lua index 19ffd257d..0c0345046 100644 --- a/service/src/main/resources/lua/remove_item_by_guid.lua +++ b/service/src/main/resources/lua/remove_item_by_guid.lua @@ -1,28 +1,32 @@ -- keys: queue_key, queue_metadata_key, queue_index -- argv: guid_to_remove -local messageId = redis.call("HGET", KEYS[2], ARGV[1]) +local removedMessages = {} -if messageId then - local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], messageId, messageId, "LIMIT", 0, 1) - local sender = redis.call("HGET", KEYS[2], messageId) +for _, guid in ipairs(ARGV) do + local messageId = redis.call("HGET", KEYS[2], guid) - redis.call("ZREMRANGEBYSCORE", KEYS[1], messageId, messageId) - redis.call("HDEL", KEYS[2], ARGV[1]) - redis.call("HDEL", KEYS[2], messageId .. "guid") + if messageId then + local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], messageId, messageId, "LIMIT", 0, 1) + local sender = redis.call("HGET", KEYS[2], messageId) - if sender then - redis.call("HDEL", KEYS[2], sender) - redis.call("HDEL", KEYS[2], messageId) - end + redis.call("ZREMRANGEBYSCORE", KEYS[1], messageId, messageId) + redis.call("HDEL", KEYS[2], guid) + redis.call("HDEL", KEYS[2], messageId .. "guid") - if (redis.call("ZCARD", KEYS[1]) == 0) then - redis.call("ZREM", KEYS[3], KEYS[1]) - end + if sender then + redis.call("HDEL", KEYS[2], sender) + redis.call("HDEL", KEYS[2], messageId) + end - if envelope and next(envelope) then - return envelope[1] + if (redis.call("ZCARD", KEYS[1]) == 0) then + redis.call("ZREM", KEYS[3], KEYS[1]) + end + + if envelope and next(envelope) then + removedMessages[#removedMessages + 1] = envelope[1] + end end end -return nil +return removedMessages diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index a0d14526a..49d7ce545 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -24,6 +24,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -113,6 +114,39 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(MessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get()); } + @Test + @Parameters({"true", "false"}) + public void testRemoveBatchByUUID(final boolean sealedSender) { + final int messageCount = 10; + + final List messagesToRemove = new ArrayList<>(messageCount); + final List messagesToPreserve = new ArrayList<>(messageCount); + + for (int i = 0; i < 10; i++) { + messagesToRemove.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); + messagesToPreserve.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); + } + + assertEquals(Collections.emptyList(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())).collect(Collectors.toList()))); + + for (final MessageProtos.Envelope message : messagesToRemove) { + messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + } + + for (final MessageProtos.Envelope message : messagesToPreserve) { + messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + } + + final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())).collect(Collectors.toList())); + + assertEquals(messagesToRemove.stream().map(message -> MessagesCache.constructEntityFromEnvelope(0, message)).collect(Collectors.toList()), + removedMessages); + + assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + } + @Test @Parameters({"true", "false"}) public void testGetMessages(final boolean sealedSender) {