Don't return message IDs from the "insert message" script

This commit is contained in:
Jon Chambers 2024-11-07 14:50:01 -05:00 committed by Jon Chambers
parent 7158a504fa
commit 1fa31b3974
4 changed files with 57 additions and 18 deletions

View File

@ -211,10 +211,13 @@ public class MessagesCache {
this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript; this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript;
} }
public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice, public void insert(final UUID messageGuid,
final UUID destinationAccountIdentifier,
final byte destinationDeviceId,
final MessageProtos.Envelope message) { final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid)); final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build();
insertTimer.record(() -> insertScript.execute(destinationAccountIdentifier, destinationDeviceId, messageWithGuid));
} }
public byte[] insertSharedMultiRecipientMessagePayload( public byte[] insertSharedMultiRecipientMessagePayload(

View File

@ -27,7 +27,7 @@ class MessagesCacheInsertScript {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
} }
long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { void execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid(); assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp(); assert envelope.hasServerTimestamp();
@ -43,6 +43,6 @@ class MessagesCacheInsertScript {
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid
)); ));
return (long) insertScript.executeBinary(keys, args); insertScript.executeBinary(keys, args);
} }
} }

View File

@ -7,8 +7,14 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID; import java.util.UUID;
import com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
@ -21,8 +27,8 @@ class MessagesCacheInsertScriptTest {
@Test @Test
void testCacheInsertScript() throws Exception { void testCacheInsertScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript( final MessagesCacheInsertScript insertScript =
REDIS_CLUSTER_EXTENSION.getRedisCluster()); new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1; final byte deviceId = 1;
@ -31,15 +37,43 @@ class MessagesCacheInsertScriptTest {
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build(); .build();
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1)); insertScript.execute(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond()) .setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build(); .build();
assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2));
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1), insertScript.execute(destinationUuid, deviceId, envelope2);
"Repeated with same guid should have same message ID");
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
insertScript.execute(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
"Messages with same GUID should be deduplicated");
}
private List<MessageProtos.Envelope> getStoredMessages(final UUID destinationUuid, final byte deviceId) throws IOException {
final MessagesCacheGetItemsScript getItemsScript =
new MessagesCacheGetItemsScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0)
.blockOptional()
.orElseGet(Collections::emptyList);
final List<MessageProtos.Envelope> messages = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size(); i += 2) {
try {
messages.add(MessageProtos.Envelope.parseFrom(queueItems.get(i)));
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
}
return messages;
} }
} }

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@ -132,8 +133,8 @@ class MessagesCacheTest {
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testInsert(final boolean sealedSender) { void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender)) > 0); generateRandomMessage(messageGuid, sealedSender)));
} }
@Test @Test
@ -141,12 +142,13 @@ class MessagesCacheTest {
final UUID duplicateGuid = UUID.randomUUID(); final UUID duplicateGuid = UUID.randomUUID();
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
final long firstId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
final long secondId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
duplicateMessage);
assertEquals(firstId, secondId); assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0)
.count()
.blockOptional()
.orElse(0L));
} }
@ParameterizedTest @ParameterizedTest