From beac73b6c8222128d2dfa87871887e7cd8536ea3 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 21 Jul 2020 11:47:53 -0400 Subject: [PATCH] Add a cluster-capable message persister --- .../storage/RedisClusterMessagePersister.java | 181 +++++++++++++++++ .../storage/RedisClusterMessagesCache.java | 109 ++++++++-- .../textsecuregcm/util/RedisClusterUtil.java | 14 +- .../RedisClusterMessagePersisterTest.java | 189 ++++++++++++++++++ .../RedisClusterMessagesCacheTest.java | 34 ++++ 5 files changed, 500 insertions(+), 27 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersister.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersisterTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersister.java new file mode 100644 index 000000000..c0aff47f2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersister.java @@ -0,0 +1,181 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.lifecycle.Managed; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static com.codahale.metrics.MetricRegistry.name; + +public class RedisClusterMessagePersister implements Managed { + + private final RedisClusterMessagesCache messagesCache; + private final Messages messagesDatabase; + private final PubSubManager pubSubManager; + private final PushSender pushSender; + private final AccountsManager accountsManager; + + private final Duration persistDelay; + + private volatile boolean running = false; + private Thread workerThread; + + private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "getQueues")); + private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "persistQueue")); + private static final Timer NOTIFY_SUBSCRIBERS_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "notifySubscribers")); + private static final DistributionSummary QUEUE_COUNT_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueCount")); + private static final DistributionSummary QUEUE_SIZE_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueSize")); + + static final int QUEUE_BATCH_LIMIT = 100; + static final int MESSAGE_BATCH_LIMIT = 100; + + private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagePersister.class); + + public RedisClusterMessagePersister(final RedisClusterMessagesCache messagesCache, final Messages messagesDatabase, final PubSubManager pubSubManager, final PushSender pushSender, final AccountsManager accountsManager, final Duration persistDelay) { + this.messagesCache = messagesCache; + this.messagesDatabase = messagesDatabase; + this.pubSubManager = pubSubManager; + this.pushSender = pushSender; + this.accountsManager = accountsManager; + + this.persistDelay = persistDelay; + } + + @Override + public void start() { + running = true; + + workerThread = new Thread(() -> { + while (running) { + persistNextQueues(Instant.now()); + Util.sleep(100); + } + }); + + workerThread.start(); + } + + @Override + public void stop() throws Exception { + running = false; + + if (workerThread != null) { + workerThread.join(); + workerThread = null; + } + } + + @VisibleForTesting + void persistNextQueues(final Instant currentTime) { + final int slot = messagesCache.getNextSlotToPersist(); + + List queuesToPersist; + int queuesPersisted = 0; + + do { + queuesToPersist = GET_QUEUES_TIMER.record(() -> messagesCache.getQueuesToPersist(slot, currentTime.minus(persistDelay), QUEUE_BATCH_LIMIT)); + + for (final String queue : queuesToPersist) { + persistQueue(queue); + notifyClients(RedisClusterMessagesCache.getAccountUuidFromQueueName(queue), RedisClusterMessagesCache.getDeviceIdFromQueueName(queue)); + } + + queuesPersisted += queuesToPersist.size(); + } while (queuesToPersist.size() == QUEUE_BATCH_LIMIT); + + QUEUE_COUNT_SUMMARY.record(queuesPersisted); + } + + @VisibleForTesting + void persistQueue(final String queue) { + final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queue); + final long deviceId = RedisClusterMessagesCache.getDeviceIdFromQueueName(queue); + + final Optional maybeAccount = accountsManager.get(accountUuid); + + final String accountNumber; + + if (maybeAccount.isPresent()) { + accountNumber = maybeAccount.get().getNumber(); + } else { + logger.error("No account record found for account {}", accountUuid); + return; + } + + PERSIST_QUEUE_TIMER.record(() -> { + messagesCache.lockQueueForPersistence(queue); + + try { + int messageCount = 0; + List messages; + + do { + messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); + + for (final MessageProtos.Envelope message : messages) { + final UUID uuid = UUID.fromString(message.getServerGuid()); + + messagesDatabase.store(uuid, message, accountNumber, deviceId); + messagesCache.remove(accountNumber, accountUuid, deviceId, uuid); + + messageCount++; + } + } while (messages.size() == MESSAGE_BATCH_LIMIT); + + QUEUE_SIZE_SUMMARY.record(messageCount); + } finally { + messagesCache.unlockQueueForPersistence(queue); + } + }); + } + + public void notifyClients(final UUID accountUuid, final long deviceId) { + NOTIFY_SUBSCRIBERS_TIMER.record(() -> { + final Optional maybeAccount = accountsManager.get(accountUuid); + + final String address; + + if (maybeAccount.isPresent()) { + address = maybeAccount.get().getNumber(); + } else { + logger.error("No account record found for account {}", accountUuid); + return; + } + + final boolean notified = pubSubManager.publish(new WebsocketAddress(address, deviceId), + PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) + .build()); + + if (!notified) { + Optional account = accountsManager.get(address); + + if (account.isPresent()) { + Optional device = account.get().getDevice(deviceId); + + if (device.isPresent()) { + try { + pushSender.sendQueuedNotification(account.get(), device.get()); + } catch (final NotPushRegisteredException e) { + logger.warn("After message persistence, no longer push registered!"); + } + } + } + } + }); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java index a5c323255..39d8b4cc3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java @@ -1,7 +1,9 @@ package org.whispersystems.textsecuregcm.storage; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.cluster.SlotHash; import io.micrometer.core.instrument.Metrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -13,6 +15,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -23,12 +26,18 @@ import static com.codahale.metrics.MetricRegistry.name; public class RedisClusterMessagesCache implements UserMessagesCache { + private final FaultTolerantRedisCluster redisCluster; + private final ClusterLuaScript insertScript; private final ClusterLuaScript removeByIdScript; private final ClusterLuaScript removeBySenderScript; private final ClusterLuaScript removeByGuidScript; private final ClusterLuaScript getItemsScript; private final ClusterLuaScript removeQueueScript; + private final ClusterLuaScript getQueuesToPersistScript; + + static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; + private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert"); private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove"); @@ -44,12 +53,15 @@ public class RedisClusterMessagesCache implements UserMessagesCache { public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException { - this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); - this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE); - this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE); - this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE); - this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); - this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); + this.redisCluster = redisCluster; + + this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); + this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE); + this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE); + this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE); + this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); + this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); + this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI); } @Override @@ -122,13 +134,13 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public Optional remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID guid) { + public Optional remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) { try { final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), getMessageQueueMetadataKey(destinationUuid, destinationDevice), getQueueIndexKey(destinationUuid, destinationDevice)), - List.of(guid.toString().getBytes(StandardCharsets.UTF_8)))); + List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8)))); if (serialized != null) { return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized))); @@ -142,11 +154,11 @@ public class RedisClusterMessagesCache implements UserMessagesCache { @Override @SuppressWarnings("unchecked") - public List get(String destination, final UUID destinationUuid, long destinationDevice, int limit) { + public List get(final String destination, final UUID destinationUuid, final long destinationDevice, final int limit) { return Metrics.timer(GET_TIMER_NAME).record(() -> { final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), getPersistInProgressKey(destinationUuid, destinationDevice)), - List.of(String.valueOf(limit).getBytes())); + List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8))); final List messageEntities; @@ -172,6 +184,35 @@ public class RedisClusterMessagesCache implements UserMessagesCache { }); } + @SuppressWarnings("unchecked") + @VisibleForTesting + List getMessagesToPersist(final UUID accountUuid, final long destinationDevice, final int limit) { + return Metrics.timer(GET_TIMER_NAME).record(() -> { + final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(accountUuid, destinationDevice), + getPersistInProgressKey(accountUuid, destinationDevice)), + List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8))); + + final List envelopes; + + if (queueItems.size() % 2 == 0) { + envelopes = new ArrayList<>(queueItems.size() / 2); + + for (int i = 0; i < queueItems.size(); i += 2) { + try { + envelopes.add(MessageProtos.Envelope.parseFrom(queueItems.get(i))); + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + } else { + logger.error("\"Get messages\" operation returned a list with a non-even number of elements."); + envelopes = Collections.emptyList(); + } + + return envelopes; + }); + } + @Override public void clear(final String destination, final UUID destinationUuid) { // TODO Remove null check in a fully UUID-based world @@ -191,7 +232,27 @@ public class RedisClusterMessagesCache implements UserMessagesCache { Collections.emptyList())); } - private static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) { + int getNextSlotToPersist() { + return (int)(redisCluster.withWriteCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT); + } + + List getQueuesToPersist(final int slot, final Instant maxTime, final int limit) { + //noinspection unchecked + return (List)getQueuesToPersistScript.execute(List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)), + List.of(String.valueOf(maxTime.toEpochMilli()), + String.valueOf(limit))); + } + + void lockQueueForPersistence(final String queue) { + redisCluster.useBinaryWriteCluster(connection -> connection.sync().setex(getPersistInProgressKey(queue), 30, LOCK_VALUE)); + } + + void unlockQueueForPersistence(final String queue) { + redisCluster.useBinaryWriteCluster(connection -> connection.sync().del(getPersistInProgressKey(queue))); + } + + @VisibleForTesting + static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) { return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } @@ -199,11 +260,29 @@ public class RedisClusterMessagesCache implements UserMessagesCache { return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) { - return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(accountUuid.toString() + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8); + private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) { + return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId)); } - private byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) { - return ("user_queue_persisting::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + private static byte[] getQueueIndexKey(final int slot) { + return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8); + } + + private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) { + return getPersistInProgressKey(accountUuid + "::" + deviceId); + } + + private static byte[] getPersistInProgressKey(final String queueName) { + return ("user_queue_persisting::{" + queueName + "}").getBytes(StandardCharsets.UTF_8); + } + + static UUID getAccountUuidFromQueueName(final String queueName) { + final int startOfHashTag = queueName.indexOf('{'); + + return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag))); + } + + static long getDeviceIdFromQueueName(final String queueName) { + return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}'))); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java index bb02d56dd..fe041733b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java @@ -1,8 +1,6 @@ package org.whispersystems.textsecuregcm.util; import io.lettuce.core.cluster.SlotHash; -import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; -import io.lettuce.core.cluster.models.partitions.RedisClusterNode; public class RedisClusterUtil { @@ -23,15 +21,7 @@ public class RedisClusterUtil { } } - /** - * Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key. - * - * @param key the key for which to find a matching hash tag - * @return a Redis hash tag that maps to the same Redis cluster slot as the given key - * - * @see Redis Cluster Specification - Keys hash tags - */ - public static String getMinimalHashTag(final String key) { - return HASHES_BY_SLOT[SlotHash.getSlot(key)]; + public static String getMinimalHashTag(final int slot) { + return HASHES_BY_SLOT[slot]; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersisterTest.java new file mode 100644 index 000000000..e0c62fe7a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagePersisterTest.java @@ -0,0 +1,189 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.google.protobuf.ByteString; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Before; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.PushSender; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RedisClusterMessagePersisterTest { + + private RedisClusterMessagesCache messagesCache; + private Messages messagesDatabase; + private RedisClusterMessagePersister messagePersister; + private AccountsManager accountsManager; + + private long serialTimestamp = 0; + + private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); + private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; + private static final long DESTINATION_DEVICE_ID = 7; + + private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); + + private static final Random RANDOM = new Random(); + + @Before + public void setUp() { + messagesCache = mock(RedisClusterMessagesCache.class); + messagesDatabase = mock(Messages.class); + accountsManager = mock(AccountsManager.class); + + final Account account = mock(Account.class); + + when(accountsManager.get(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account)); + when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); + + messagePersister = new RedisClusterMessagePersister(messagesCache, messagesDatabase, mock(PubSubManager.class), mock(PushSender.class), accountsManager, PERSIST_DELAY); + } + + @Test + public void testPersistNextQueuesNoQueues() { + final int slot = 7; + + when(messagesCache.getNextSlotToPersist()).thenReturn(slot); + when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(Collections.emptyList()); + + messagePersister.persistNextQueues(Instant.now()); + + verify(messagesCache, never()).lockQueueForPersistence(any()); + } + + @Test + public void testPersistNextQueuesSingleQueue() { + final int slot = 7; + final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + + when(messagesCache.getNextSlotToPersist()).thenReturn(slot); + when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(List.of(queueName)); + + messagePersister.persistNextQueues(Instant.now()); + + verify(messagesCache).lockQueueForPersistence(queueName); + } + + @Test + public void testPersistNextQueuesMultiplePages() { + final int slot = 7; + final int queueCount = RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3; + + final List queues = new ArrayList<>(queueCount); + + for (int i = 0; i < queueCount; i++) { + final String queueName = generateRandomQueueName(); + final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queueName); + + queues.add(queueName); + + final Account account = mock(Account.class); + + when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account)); + when(account.getNumber()).thenReturn("+1" + RandomStringUtils.randomNumeric(10)); + } + + when(messagesCache.getNextSlotToPersist()).thenReturn(slot); + when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))) + .thenReturn(queues.subList(0, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT)) + .thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2)) + .thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3)) + .thenReturn(Collections.emptyList()); + + messagePersister.persistNextQueues(Instant.now()); + + verify(messagesCache, times(queueCount)).lockQueueForPersistence(any()); + } + + @Test + public void testPersistQueueNoMessages() { + final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + + when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(Collections.emptyList()); + + messagePersister.persistQueue(queueName); + + verify(messagesCache).lockQueueForPersistence(queueName); + verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT); + verify(messagesDatabase, never()).store(any(), any(), any(), anyLong()); + verify(messagesCache, never()).remove(anyString(), any(UUID.class), anyLong(), any(UUID.class)); + verify(messagesCache).unlockQueueForPersistence(queueName); + } + + @Test + public void testPersistQueueSingleMessage() { + final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + + final MessageProtos.Envelope message = generateRandomMessage(); + + when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(List.of(message)); + + messagePersister.persistQueue(queueName); + + verify(messagesCache).lockQueueForPersistence(queueName); + verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT); + verify(messagesDatabase).store(UUID.fromString(message.getServerGuid()), message, DESTINATION_ACCOUNT_NUMBER, DESTINATION_DEVICE_ID); + verify(messagesCache).remove(DESTINATION_ACCOUNT_NUMBER, DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, UUID.fromString(message.getServerGuid())); + verify(messagesCache).unlockQueueForPersistence(queueName); + } + + @Test + public void testPersistQueueMultiplePages() { + final int messageCount = RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3; + final List messagesInQueue = new ArrayList<>(messageCount); + + for (int i = 0; i < messageCount; i++) { + messagesInQueue.add(generateRandomMessage()); + } + + when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)) + .thenReturn(messagesInQueue.subList(0, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)) + .thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2)) + .thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3)) + .thenReturn(Collections.emptyList()); + + final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + + messagePersister.persistQueue(queueName); + + verify(messagesCache).lockQueueForPersistence(queueName); + verify(messagesCache, times(4)).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT); + verify(messagesDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + verify(messagesCache, times(messageCount)).remove(eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID), any(UUID.class)); + verify(messagesCache).unlockQueueForPersistence(queueName); + } + + private MessageProtos.Envelope generateRandomMessage() { + final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() + .setTimestamp(serialTimestamp++) + .setServerTimestamp(serialTimestamp++) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(UUID.randomUUID().toString()); + + return envelopeBuilder.build(); + } + + private String generateRandomQueueName() { + return String.format("user_queue::{%s::%d}", UUID.randomUUID().toString(), RANDOM.nextInt(10)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java index 6efe93ef9..7efc5ca06 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java @@ -1,13 +1,18 @@ package org.whispersystems.textsecuregcm.storage; +import io.lettuce.core.cluster.SlotHash; import junitparams.Parameters; import org.junit.Before; import org.junit.Test; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; import java.util.UUID; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { @@ -52,4 +57,33 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { // We're happy as long as this doesn't throw an exception messagesCache.clear(DESTINATION_ACCOUNT, null); } + + @Test + public void testGetAccountFromQueueName() { + assertEquals(DESTINATION_UUID, + RedisClusterMessagesCache.getAccountUuidFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8))); + } + + @Test + public void testGetDeviceIdFromQueueName() { + assertEquals(DESTINATION_DEVICE_ID, + RedisClusterMessagesCache.getDeviceIdFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8))); + } + + @Test + @Parameters({"true", "false"}) + public void testGetQueuesToPersist(final boolean sealedSender) { + final UUID messageGuid = UUID.randomUUID(); + + messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)); + final int slot = SlotHash.getSlot(DESTINATION_UUID.toString() + "::" + DESTINATION_DEVICE_ID); + + assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty()); + + final List queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100); + + assertEquals(1, queues.size()); + assertEquals(DESTINATION_UUID, RedisClusterMessagesCache.getAccountUuidFromQueueName(queues.get(0))); + assertEquals(DESTINATION_DEVICE_ID, RedisClusterMessagesCache.getDeviceIdFromQueueName(queues.get(0))); + } }