diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 49bc74abd..4a714a82a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -141,6 +141,7 @@ public class MessagePersister implements Managed { try (final Timer.Context ignored = persistQueueTimer.time()) { messagesCache.lockQueueForPersistence(accountUuid, deviceId); + messagesCache.repairMetadata(accountUuid, deviceId); try { int messageCount = 0; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 8100f2c0f..584ea38d3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -55,6 +55,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private final ClusterLuaScript getItemsScript; private final ClusterLuaScript removeQueueScript; private final ClusterLuaScript getQueuesToPersistScript; + private final ClusterLuaScript repairMetadataScript; private final Map messageListenersByQueueName = new HashMap<>(); private final Map queueNamesByMessageListener = new IdentityHashMap<>(); @@ -65,6 +66,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist")); private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); private final Timer takeEphemeralMessageTimer = Metrics.timer(name(MessagesCache.class, "takeEphemeral")); + private final Timer repairMetadataTimer = Metrics.timer(name(MessagesCache.class, "repairMetadata")); private final Counter pubSubMessageCounter = Metrics.counter(name(MessagesCache.class, "pubSubMessage")); private final Counter newMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "false"); private final Counter ephemeralMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "true"); @@ -102,6 +104,8 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI); + + this.repairMetadataScript = ClusterLuaScript.fromResource(redisCluster, "lua/repair_queue_metadata.lua", ScriptOutputType.VALUE); } @Override @@ -220,6 +224,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return removedMessages; } + @VisibleForTesting + void repairMetadata(final UUID destinationUuid, final long destinationDevice) { + repairMetadataTimer.record(() -> { + repairMetadataScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice)), + Collections.emptyList()); + }); + } + @SuppressWarnings("unchecked") public List get(final UUID destinationUuid, final long destinationDevice, final int limit) { return getMessagesTimer.record(() -> { diff --git a/service/src/main/resources/lua/insert_item.lua b/service/src/main/resources/lua/insert_item.lua index 0878e4ab1..28808dc59 100644 --- a/service/src/main/resources/lua/insert_item.lua +++ b/service/src/main/resources/lua/insert_item.lua @@ -16,7 +16,6 @@ if sender ~= "nil" then end redis.call("HSET", queueMetadataKey, guid, messageId) - redis.call("HSET", queueMetadataKey, messageId .. "guid", guid) redis.call("EXPIRE", queueKey, 7776000) -- 90 days diff --git a/service/src/main/resources/lua/repair_queue_metadata.lua b/service/src/main/resources/lua/repair_queue_metadata.lua new file mode 100644 index 000000000..0c3b914d0 --- /dev/null +++ b/service/src/main/resources/lua/repair_queue_metadata.lua @@ -0,0 +1,21 @@ +local queueKey = KEYS[1] +local queueMetadataKey = KEYS[2] + +local firstMessageWithScore = redis.call("ZRANGE", queueKey, 0, 0, "WITHSCORES") +local lastMessageWithScore = redis.call("ZRANGE", queueKey, -1, -1, "WITHSCORES") + +if firstMessageWithScore ~= nil and lastMessageWithScore ~= nil then + local firstMessageId = tonumber(firstMessageWithScore[2]) + local lastMessageId = tonumber(lastMessageWithScore[2]) + + for messageId = firstMessageId,lastMessageId do + if redis.call("ZRANGEBYSCORE", queueKey, messageId, messageId) then + -- This message actually exists, and its GUID may be pointing to the wrong ID + local guid = redis.call("HGET", queueMetadataKey, messageId .. "guid") + redis.call("HSET", queueMetadataKey, guid, messageId) + else + -- No message actually exists with that ID; drop the metadata reference to that ID + redis.call("HDEL", queueMetadataKey, messageId .. "guid") + end + end +end diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 43b3e4ef3..96df6cd13 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -27,6 +27,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @RunWith(JUnitParamsRunner.class) @@ -71,6 +72,41 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0); } + @Test + public void testRepairMetadata() { + final int distinctUuidCount = 17; + + for (int i = 0; i < distinctUuidCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, false)); + } + + assertEquals(distinctUuidCount, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, 100).size()); + + final int duplicateGuidCount = 5; + + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope duplicatedMessage = generateRandomMessage(messageGuid, false); + + for (int i = 0; i < duplicateGuidCount; i++) { + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicatedMessage); + } + + assertEquals(distinctUuidCount + 1, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, 100).size()); + assertFalse(messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid).isPresent()); + + messagesCache.repairMetadata(DESTINATION_UUID, DESTINATION_DEVICE_ID); + + assertTrue(messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid).isPresent()); + + final List messagesToPersist = messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, 100); + assertEquals(distinctUuidCount, messagesToPersist.size()); + + messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesToPersist.stream().map(message -> UUID.fromString(message.getServerGuid())).collect(Collectors.toList())); + + assertTrue(messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, 100).isEmpty()); + } + @Test @Parameters({"true", "false"}) public void testRemoveById(final boolean sealedSender) {