From 5bc6ff0e778468d6ff5a832e1cd09ffcdc085618 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 6 Sep 2024 16:11:43 -0500 Subject: [PATCH] Add check for existing key to MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript --- .../textsecuregcm/storage/MessagesCache.java | 22 +++++++++--------- ...edMultiRecipientPayloadAndViewsScript.java | 2 ++ .../storage/MessagesManager.java | 4 ++-- ...ert_shared_multirecipient_message_data.lua | 4 ++++ ...ltiRecipientPayloadAndViewsScriptTest.java | 23 ++++++++++++++++++- .../storage/MessagesCacheTest.java | 10 ++++---- 6 files changed, 45 insertions(+), 20 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 40ea5d372..0797e3f12 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -270,23 +270,24 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid)); } - public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid, - SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { - final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid); - insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage)); - return sharedMrmKey; + public byte[] insertSharedMultiRecipientMessagePayload( + final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { + return insertSharedMrmPayloadTimer.record(() -> { + final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID()); + insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage); + return sharedMrmKey; + }); } - public CompletableFuture> remove(final UUID destinationUuid, - final byte destinationDevice, + public CompletableFuture> remove(final UUID destinationUuid, final byte destinationDevice, final UUID messageGuid) { return remove(destinationUuid, destinationDevice, List.of(messageGuid)) .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst())); } - public CompletableFuture> remove(final UUID destinationUuid, - final byte destinationDevice, final List messageGuids) { + public CompletableFuture> remove(final UUID destinationUuid, final byte destinationDevice, + final List messageGuids) { final Timer.Sample sample = Timer.start(); @@ -469,8 +470,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp /** * Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure */ - void removeRecipientViewFromMrmData(final List sharedMrmKeys, final UUID accountUuid, - final byte deviceId) { + void removeRecipientViewFromMrmData(final List sharedMrmKeys, final UUID accountUuid, final byte deviceId) { if (sharedMrmKeys.isEmpty()) { return; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java index 28e8f0d59..b0f3ef5db 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java @@ -23,6 +23,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript { private final ClusterLuaScript script; + static final String ERROR_KEY_EXISTS = "ERR key exists"; + MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster) throws IOException { this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua", diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 20580cfd8..fa25a270f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -207,8 +207,8 @@ public class MessagesManager { * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript */ public byte[] insertSharedMultiRecipientMessagePayload( - SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { - return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage); + final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { + return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage); } /** diff --git a/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua b/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua index 214bce4b4..7beaaf21b 100644 --- a/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua +++ b/service/src/main/resources/lua/insert_shared_multirecipient_message_data.lua @@ -4,6 +4,10 @@ local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data -- the remainder of ARGV is list of recipient keys and view data +if 1 == redis.call("EXISTS", sharedMrmKey) then + return redis.error_reply("ERR key exists") +end + redis.call("HSET", sharedMrmKey, "data", mrmData); redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java index e1ceea81d..c5f61e8c6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.ArrayList; import java.util.HashMap; @@ -14,6 +15,8 @@ import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.IntStream; +import io.lettuce.core.RedisCommandExecutionException; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -41,7 +44,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest { final int totalDevices = destinations.values().stream().mapToInt(List::size).sum(); final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() .withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey)); - assertEquals(totalDevices + 1, hashFieldCount); + // + 1 because of "data" field + assertEquals(1 + totalDevices, hashFieldCount); } public static List testInsert() { @@ -71,4 +75,21 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest { return testCases; } + @Test + void testInsertDuplicateKey() throws Exception { + final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( + REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); + insertMrmScript.execute(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)); + + final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class, + () -> insertMrmScript.execute(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), + Device.PRIMARY_ID))); + + assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage()); + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 189bac16f..54061bc61 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -569,13 +569,12 @@ class MessagesCacheTest { final UUID destinationUuid = UUID.randomUUID(); final byte deviceId = 1; - final UUID mrmGuid = UUID.randomUUID(); final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( new AciServiceIdentifier(destinationUuid), deviceId); final byte[] sharedMrmDataKey; if (sharedMrmKeyPresent) { - sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm); + sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); } else { sharedMrmDataKey = new byte[]{1}; } @@ -593,7 +592,7 @@ class MessagesCacheTest { messagesCache.insert(guid, destinationUuid, deviceId, message); assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() - .withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)))); + .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); final List messages = get(destinationUuid, deviceId, 1); assertEquals(1, messages.size()); @@ -616,9 +615,9 @@ class MessagesCacheTest { boolean exists; do { exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster() - .withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))); + .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)); } while (exists); - }); + }, "Shared MRM data should be deleted asynchronously"); } private List get(final UUID destinationUuid, final byte destinationDeviceId, @@ -628,7 +627,6 @@ class MessagesCacheTest { .collectList() .block(); } - } @Nested