diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 7db90e492..75eb6346b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -211,10 +211,13 @@ public class MessagesCache { this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript; } - public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice, + public void insert(final UUID messageGuid, + final UUID destinationAccountIdentifier, + final byte destinationDeviceId, final MessageProtos.Envelope message) { - final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); - return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid)); + + final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build(); + insertTimer.record(() -> insertScript.execute(destinationAccountIdentifier, destinationDeviceId, messageWithGuid)); } public byte[] insertSharedMultiRecipientMessagePayload( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java index 87e3ba664..249ba5d78 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java @@ -27,7 +27,7 @@ class MessagesCacheInsertScript { this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); } - long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { + void execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { assert envelope.hasServerGuid(); assert envelope.hasServerTimestamp(); @@ -43,6 +43,6 @@ class MessagesCacheInsertScript { envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid )); - return (long) insertScript.executeBinary(keys, args); + insertScript.executeBinary(keys, args); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java index febe8c32e..b43c681a1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java @@ -7,8 +7,14 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.io.IOException; +import java.io.UncheckedIOException; import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.UUID; +import com.google.protobuf.InvalidProtocolBufferException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -21,8 +27,8 @@ class MessagesCacheInsertScriptTest { @Test void testCacheInsertScript() throws Exception { - final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( - REDIS_CLUSTER_EXTENSION.getRedisCluster()); + final MessagesCacheInsertScript insertScript = + new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster()); final UUID destinationUuid = UUID.randomUUID(); final byte deviceId = 1; @@ -31,15 +37,43 @@ class MessagesCacheInsertScriptTest { .setServerGuid(UUID.randomUUID().toString()) .build(); - assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1)); + insertScript.execute(destinationUuid, deviceId, envelope1); + + assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId)); final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder() .setServerTimestamp(Instant.now().getEpochSecond()) .setServerGuid(UUID.randomUUID().toString()) .build(); - assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2)); - assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1), - "Repeated with same guid should have same message ID"); + insertScript.execute(destinationUuid, deviceId, envelope2); + + assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId)); + + insertScript.execute(destinationUuid, deviceId, envelope1); + + assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId), + "Messages with same GUID should be deduplicated"); + } + + private List getStoredMessages(final UUID destinationUuid, final byte deviceId) throws IOException { + final MessagesCacheGetItemsScript getItemsScript = + new MessagesCacheGetItemsScript(REDIS_CLUSTER_EXTENSION.getRedisCluster()); + + final List queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0) + .blockOptional() + .orElseGet(Collections::emptyList); + + final List messages = new ArrayList<>(queueItems.size() / 2); + + for (int i = 0; i < queueItems.size(); i += 2) { + try { + messages.add(MessageProtos.Envelope.parseFrom(queueItems.get(i))); + } catch (final InvalidProtocolBufferException e) { + throw new UncheckedIOException(e); + } + } + + return messages; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 299898323..d5d1817ed 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; @@ -132,8 +133,8 @@ class MessagesCacheTest { @ValueSource(booleans = {true, false}) void testInsert(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); - assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, sealedSender)) > 0); + assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid, sealedSender))); } @Test @@ -141,12 +142,13 @@ class MessagesCacheTest { final UUID duplicateGuid = UUID.randomUUID(); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); - final long firstId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - duplicateMessage); - final long secondId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - duplicateMessage); + messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); + messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); - assertEquals(firstId, secondId); + assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0) + .count() + .blockOptional() + .orElse(0L)); } @ParameterizedTest