Add check for existing key to MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript

This commit is contained in:
Chris Eager 2024-09-06 16:11:43 -05:00 committed by Chris Eager
parent 020c21f4ef
commit 5bc6ff0e77
6 changed files with 45 additions and 20 deletions

View File

@ -270,23 +270,24 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid));
}
public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid,
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid);
insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage));
return sharedMrmKey;
public byte[] insertSharedMultiRecipientMessagePayload(
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return insertSharedMrmPayloadTimer.record(() -> {
final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage);
return sharedMrmKey;
});
}
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice,
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst()));
}
public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice, final List<UUID> messageGuids) {
public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
final List<UUID> messageGuids) {
final Timer.Sample sample = Timer.start();
@ -469,8 +470,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
/**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid,
final byte deviceId) {
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid, final byte deviceId) {
if (sharedMrmKeys.isEmpty()) {
return;

View File

@ -23,6 +23,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
private final ClusterLuaScript script;
static final String ERROR_KEY_EXISTS = "ERR key exists";
MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster)
throws IOException {
this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua",

View File

@ -207,8 +207,8 @@ public class MessagesManager {
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/
public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage);
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage);
}
/**

View File

@ -4,6 +4,10 @@ local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data
local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data
-- the remainder of ARGV is list of recipient keys and view data
if 1 == redis.call("EXISTS", sharedMrmKey) then
return redis.error_reply("ERR key exists")
end
redis.call("HSET", sharedMrmKey, "data", mrmData);
redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList;
import java.util.HashMap;
@ -14,6 +15,8 @@ import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import io.lettuce.core.RedisCommandExecutionException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@ -41,7 +44,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey));
assertEquals(totalDevices + 1, hashFieldCount);
// + 1 because of "data" field
assertEquals(1 + totalDevices, hashFieldCount);
}
public static List<Arguments> testInsert() {
@ -71,4 +75,21 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
return testCases;
}
@Test
void testInsertDuplicateKey() throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID));
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class,
() -> insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
Device.PRIMARY_ID)));
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage());
}
}

View File

@ -569,13 +569,12 @@ class MessagesCacheTest {
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID mrmGuid = UUID.randomUUID();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm);
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
} else {
sharedMrmDataKey = new byte[]{1};
}
@ -593,7 +592,7 @@ class MessagesCacheTest {
messagesCache.insert(guid, destinationUuid, deviceId, message);
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))));
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1);
assertEquals(1, messages.size());
@ -616,9 +615,9 @@ class MessagesCacheTest {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)));
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey));
} while (exists);
});
}, "Shared MRM data should be deleted asynchronously");
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
@ -628,7 +627,6 @@ class MessagesCacheTest {
.collectList()
.block();
}
}
@Nested