Add check for existing key to MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript

This commit is contained in:
Chris Eager 2024-09-06 16:11:43 -05:00 committed by Chris Eager
parent 020c21f4ef
commit 5bc6ff0e77
6 changed files with 45 additions and 20 deletions

View File

@ -270,23 +270,24 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid)); return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid));
} }
public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid, public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid); return insertSharedMrmPayloadTimer.record(() -> {
insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage)); final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID());
return sharedMrmKey; insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage);
return sharedMrmKey;
});
} }
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
final byte destinationDevice,
final UUID messageGuid) { final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid)) return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst())); .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst()));
} }
public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid, public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
final byte destinationDevice, final List<UUID> messageGuids) { final List<UUID> messageGuids) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
@ -469,8 +470,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
/** /**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure * Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/ */
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid, void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid, final byte deviceId) {
final byte deviceId) {
if (sharedMrmKeys.isEmpty()) { if (sharedMrmKeys.isEmpty()) {
return; return;

View File

@ -23,6 +23,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
private final ClusterLuaScript script; private final ClusterLuaScript script;
static final String ERROR_KEY_EXISTS = "ERR key exists";
MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster) MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster)
throws IOException { throws IOException {
this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua", this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua",

View File

@ -207,8 +207,8 @@ public class MessagesManager {
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/ */
public byte[] insertSharedMultiRecipientMessagePayload( public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage); return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage);
} }
/** /**

View File

@ -4,6 +4,10 @@ local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data
local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data
-- the remainder of ARGV is list of recipient keys and view data -- the remainder of ARGV is list of recipient keys and view data
if 1 == redis.call("EXISTS", sharedMrmKey) then
return redis.error_reply("ERR key exists")
end
redis.call("HSET", sharedMrmKey, "data", mrmData); redis.call("HSET", sharedMrmKey, "data", mrmData);
redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -14,6 +15,8 @@ import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import io.lettuce.core.RedisCommandExecutionException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
@ -41,7 +44,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum(); final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey)); .withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey));
assertEquals(totalDevices + 1, hashFieldCount); // + 1 because of "data" field
assertEquals(1 + totalDevices, hashFieldCount);
} }
public static List<Arguments> testInsert() { public static List<Arguments> testInsert() {
@ -71,4 +75,21 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
return testCases; return testCases;
} }
@Test
void testInsertDuplicateKey() throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID));
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class,
() -> insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
Device.PRIMARY_ID)));
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage());
}
} }

View File

@ -569,13 +569,12 @@ class MessagesCacheTest {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1; final byte deviceId = 1;
final UUID mrmGuid = UUID.randomUUID();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId); new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey; final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) { if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm); sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
} else { } else {
sharedMrmDataKey = new byte[]{1}; sharedMrmDataKey = new byte[]{1};
} }
@ -593,7 +592,7 @@ class MessagesCacheTest {
messagesCache.insert(guid, destinationUuid, deviceId, message); messagesCache.insert(guid, destinationUuid, deviceId, message);
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)))); .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1); final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1);
assertEquals(1, messages.size()); assertEquals(1, messages.size());
@ -616,9 +615,9 @@ class MessagesCacheTest {
boolean exists; boolean exists;
do { do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster() exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))); .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey));
} while (exists); } while (exists);
}); }, "Shared MRM data should be deleted asynchronously");
} }
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId, private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
@ -628,7 +627,6 @@ class MessagesCacheTest {
.collectList() .collectList()
.block(); .block();
} }
} }
@Nested @Nested