From 8d3316ccd60ba85896b08078c6ffe476e589d750 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 6 Aug 2020 11:21:55 -0400 Subject: [PATCH] Listen for new messages via keyspace notifications. --- .../textsecuregcm/WhisperServerService.java | 3 +- .../push/ClientPresenceManager.java | 11 +-- .../storage/MessageAvailabilityListener.java | 13 +++ .../storage/MessagesManager.java | 8 ++ .../storage/RedisClusterMessagesCache.java | 88 +++++++++++++++++- .../textsecuregcm/util/RedisClusterUtil.java | 33 +++++++ .../AuthenticatedConnectListener.java | 2 + .../websocket/WebSocketConnection.java | 20 +++- .../RedisClusterMessagesCacheTest.java | 92 +++++++++++++++++-- .../util/RedisClusterUtilTest.java | 64 +++++++++++++ 10 files changed, 313 insertions(+), 21 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d1d4beb9b..40cf0d53a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -338,6 +338,7 @@ public class WhisperServerService extends Application(1_000)).build(); ExecutorService websocketExperimentExecutor = environment.lifecycle().executorService("websocketPresenceExperiment").maxThreads(8).workQueue(new ArrayBlockingQueue<>(1_000)).build(); ClientPresenceManager clientPresenceManager = new ClientPresenceManager(messagesCacheCluster, clientPresenceExecutor); @@ -349,7 +350,7 @@ public class WhisperServerService extends Application { - final String configuredKeyspaceNotifications = connection.sync().configGet("notify-keyspace-events").getOrDefault("notify-keyspace-events", ""); - - for (final char requiredNotificationType : new char[] {'K', '$'}) { - if (configuredKeyspaceNotifications.indexOf(requiredNotificationType) == -1) { - throw new IllegalStateException("Required keyspace notification type not configured. Need at least K$, but is actually: " + configuredKeyspaceNotifications); - } - } - connection.addListener(this); connection.getResources().eventBus().get() .filter(event -> event instanceof ClusterTopologyChangedEvent) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java new file mode 100644 index 000000000..3eeffc089 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java @@ -0,0 +1,13 @@ +package org.whispersystems.textsecuregcm.storage; + +/** + * A message availability listener is notified when new messages are available for a specific device for a specific + * account. Availability listeners are also notified when messages are moved from the message cache to long-term storage + * as an optimization hint to implementing classes. + */ +public interface MessageAvailabilityListener { + + void handleNewMessagesAvailable(); + + void handleMessagesPersisted(); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index cbc14698d..c9e526c0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -127,4 +127,12 @@ public class MessagesManager { final Optional maybeRemovedMessage = messagesCache.remove(destination, destinationUuid, deviceId, id); removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationUuid, deviceId, id), experimentExecutor); } + + public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, final MessageAvailabilityListener listener) { + clusterMessagesCache.addMessageAvailabilityListener(destinationUuid, deviceId, listener); + } + + public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) { + clusterMessagesCache.removeMessageAvailabilityListener(listener); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java index 39d8b4cc3..3e8b63559 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java @@ -4,6 +4,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; import io.lettuce.core.ScriptOutputType; import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.micrometer.core.instrument.Metrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -18,15 +21,20 @@ import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.ExecutorService; import static com.codahale.metrics.MetricRegistry.name; -public class RedisClusterMessagesCache implements UserMessagesCache { +public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter implements UserMessagesCache { private final FaultTolerantRedisCluster redisCluster; + private final ExecutorService notificationExecutorService; private final ClusterLuaScript insertScript; private final ClusterLuaScript removeByIdScript; @@ -36,9 +44,15 @@ public class RedisClusterMessagesCache implements UserMessagesCache { private final ClusterLuaScript removeQueueScript; private final ClusterLuaScript getQueuesToPersistScript; + private final Map messageListenersByQueueName = new HashMap<>(); + private final Map queueNamesByMessageListener = new IdentityHashMap<>(); + static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); + private static final String QUEUE_KEYSPACE_PATTERN = "__keyspace@0__:user_queue::*"; + private static final String PERSISTING_KEYSPACE_PATTERN = "__keyspace@0__:user_queue_persisting::*"; + private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert"); private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove"); private static final String GET_TIMER_NAME = name(RedisClusterMessagesCache.class, "get"); @@ -51,9 +65,10 @@ public class RedisClusterMessagesCache implements UserMessagesCache { private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagesCache.class); - public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException { + public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService) throws IOException { - this.redisCluster = redisCluster; + this.redisCluster = redisCluster; + this.notificationExecutorService = notificationExecutorService; this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE); @@ -62,6 +77,24 @@ public class RedisClusterMessagesCache implements UserMessagesCache { this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI); + + RedisClusterUtil.assertKeyspaceNotificationsConfigured(redisCluster, "K$gz"); + + redisCluster.usePubSubConnection(connection -> { + connection.addListener(this); + connection.getResources().eventBus().get() + .filter(event -> event instanceof ClusterTopologyChangedEvent) + .handle((event, sink) -> { + resubscribeAll(); + sink.next(event); + }); + + connection.sync().masters().commands().psubscribe(QUEUE_KEYSPACE_PATTERN, PERSISTING_KEYSPACE_PATTERN); + }); + } + + private void resubscribeAll() { + redisCluster.usePubSubConnection(connection -> connection.sync().masters().commands().psubscribe(QUEUE_KEYSPACE_PATTERN, PERSISTING_KEYSPACE_PATTERN)); } @Override @@ -251,6 +284,55 @@ public class RedisClusterMessagesCache implements UserMessagesCache { redisCluster.useBinaryWriteCluster(connection -> connection.sync().del(getPersistInProgressKey(queue))); } + public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, final MessageAvailabilityListener listener) { + final String queueName = getQueueName(destinationUuid, deviceId); + + synchronized (messageListenersByQueueName) { + messageListenersByQueueName.put(queueName, listener); + queueNamesByMessageListener.put(listener, queueName); + } + } + + public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) { + synchronized (messageListenersByQueueName) { + final String queueName = queueNamesByMessageListener.remove(listener); + + if (queueName != null) { + messageListenersByQueueName.remove(queueName); + } + } + } + + @Override + public void message(final RedisClusterNode node, final String pattern, final String channel, final String message) { + if (QUEUE_KEYSPACE_PATTERN.equals(pattern) && "zadd".equals(message)) { + notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable)); + } else if (PERSISTING_KEYSPACE_PATTERN.equals(pattern) && "del".equals(message)) { + notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleMessagesPersisted)); + } + } + + private Optional findListener(final String keyspaceChannel) { + final String queueName = getQueueNameFromKeyspaceChannel(keyspaceChannel); + + synchronized (messageListenersByQueueName) { + return Optional.ofNullable(messageListenersByQueueName.get(queueName)); + } + } + + @VisibleForTesting + static String getQueueName(final UUID accountUuid, final long deviceId) { + return accountUuid + "::" + deviceId; + } + + @VisibleForTesting + static String getQueueNameFromKeyspaceChannel(final String channel) { + final int startOfHashTag = channel.indexOf('{'); + final int endOfHashTag = channel.lastIndexOf('}'); + + return channel.substring(startOfHashTag + 1, endOfHashTag); + } + @VisibleForTesting static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) { return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java index fe041733b..803ff7b86 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java @@ -1,6 +1,7 @@ package org.whispersystems.textsecuregcm.util; import io.lettuce.core.cluster.SlotHash; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; public class RedisClusterUtil { @@ -21,7 +22,39 @@ public class RedisClusterUtil { } } + /** + * Returns a Redis hash tag that maps to the given cluster slot. + * + * @param slot the Redis cluster slot for which to retrieve a hash tag + * + * @return a Redis hash tag that maps to the given cluster slot + * + * @see Redis Cluster Specification - Keys hash tags + */ public static String getMinimalHashTag(final int slot) { return HASHES_BY_SLOT[slot]; } + + /** + * Asserts that a Redis cluster is configured to generate (at least) a specific set of keyspace notification events. + * + * @param redisCluster the Redis cluster to check for the required keyspace notification configuration + * @param requiredKeyspaceNotifications a string representing the required keyspace notification events (e.g. "Kg$lz") + * + * @throws IllegalStateException if the given Redis cluster is not configured to generate the required keyspace + * notification events + * + * @see Redis Keyspace Notifications - Configuration + */ + public static void assertKeyspaceNotificationsConfigured(final FaultTolerantRedisCluster redisCluster, final String requiredKeyspaceNotifications) { + final String configuredKeyspaceNotifications = redisCluster.withReadCluster(connection -> connection.sync().configGet("notify-keyspace-events")) + .getOrDefault("notify-keyspace-events", "") + .replace("A", "g$lshztxe"); + + for (final char requiredNotificationType : requiredKeyspaceNotifications.toCharArray()) { + if (configuredKeyspaceNotifications.indexOf(requiredNotificationType) == -1) { + throw new IllegalStateException(String.format("Required at least \"%s\" for keyspace notifications, but only had \"%s\".", requiredKeyspaceNotifications, configuredKeyspaceNotifications)); + } + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 1332f3b8d..acc750eeb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -76,6 +76,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.inc(); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); clientPresenceManager.setPresent(account.getUuid(), device.getId(), explicitDisplacementMeter::mark); + messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); pubSubManager.publish(address, connectMessage); pubSubManager.subscribe(address, connection); @@ -85,6 +86,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.dec(); pubSubManager.unsubscribe(address, connection); clientPresenceManager.clearPresence(account.getUuid(), device.getId()); + messagesManager.removeMessageAvailabilityListener(connection); timer.stop(); } }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 24860dd03..404328544 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -15,12 +15,12 @@ import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; -import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil; @@ -39,12 +39,16 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class WebSocketConnection implements DispatchChannel { +public class WebSocketConnection implements DispatchChannel, MessageAvailabilityListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); private static final Meter pubSubDisplacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubDisplacement")); + private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); + private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); + private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); + private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -81,9 +85,11 @@ public class WebSocketConnection implements DispatchChannel { switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: + pubSubPersistedMeter.mark(); processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: + pubSubNewMessageMeter.mark(); sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty(), false); break; case PubSubMessage.Type.CONNECTED_VALUE: @@ -214,6 +220,16 @@ public class WebSocketConnection implements DispatchChannel { } } + @Override + public void handleNewMessagesAvailable() { + messageAvailableMeter.mark(); + } + + @Override + public void handleMessagesPersisted() { + messagesPersistedMeter.mark(); + } + private static class StoredMessageInfo { private final long id; private final boolean cached; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java index 7efc5ca06..9bc452d8f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java @@ -5,11 +5,14 @@ import junitparams.Parameters; import org.junit.Before; import org.junit.Test; -import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.List; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -20,6 +23,7 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final int DESTINATION_DEVICE_ID = 7; + private ExecutorService notificationExecutorService; private RedisClusterMessagesCache messagesCache; @Override @@ -27,13 +31,18 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { public void setUp() throws Exception { super.setUp(); - try { - messagesCache = new RedisClusterMessagesCache(getRedisCluster()); - } catch (final IOException e) { - throw new RuntimeException(e); - } + getRedisCluster().useWriteCluster(connection -> connection.sync().masters().commands().configSet("notify-keyspace-events", "K$gz")); - getRedisCluster().useWriteCluster(connection -> connection.sync().flushall()); + notificationExecutorService = Executors.newSingleThreadExecutor(); + messagesCache = new RedisClusterMessagesCache(getRedisCluster(), notificationExecutorService); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + + notificationExecutorService.shutdown(); + notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); } @Override @@ -70,6 +79,12 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { RedisClusterMessagesCache.getDeviceIdFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8))); } + @Test + public void testGetQueueNameFromKeyspaceChannel() { + assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7", + RedisClusterMessagesCache.getQueueNameFromKeyspaceChannel("__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}")); + } + @Test @Parameters({"true", "false"}) public void testGetQueuesToPersist(final boolean sealedSender) { @@ -86,4 +101,67 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { assertEquals(DESTINATION_UUID, RedisClusterMessagesCache.getAccountUuidFromQueueName(queues.get(0))); assertEquals(DESTINATION_DEVICE_ID, RedisClusterMessagesCache.getDeviceIdFromQueueName(queues.get(0))); } + + @Test(timeout = 5_000L) + public void testNotifyListenerNewMessage() throws InterruptedException { + final AtomicBoolean notified = new AtomicBoolean(false); + final UUID messageGuid = UUID.randomUUID(); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public void handleNewMessagesAvailable() { + synchronized (notified) { + notified.set(true); + notified.notifyAll(); + } + } + + @Override + public void handleMessagesPersisted() { + } + }; + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, true)); + + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } + } + + assertTrue(notified.get()); + } + + @Test(timeout = 5_000L) + public void testNotifyListenerPersisted() throws InterruptedException { + final AtomicBoolean notified = new AtomicBoolean(false); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public void handleNewMessagesAvailable() { + } + + @Override + public void handleMessagesPersisted() { + synchronized (notified) { + notified.set(true); + notified.notifyAll(); + } + } + }; + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + + messagesCache.lockQueueForPersistence(RedisClusterMessagesCache.getQueueName(DESTINATION_UUID, DESTINATION_DEVICE_ID)); + messagesCache.unlockQueueForPersistence(RedisClusterMessagesCache.getQueueName(DESTINATION_UUID, DESTINATION_DEVICE_ID)); + + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } + } + + assertTrue(notified.get()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java new file mode 100644 index 000000000..f4d144ab8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/RedisClusterUtilTest.java @@ -0,0 +1,64 @@ +package org.whispersystems.textsecuregcm.util; + +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.api.sync.Executions; +import io.lettuce.core.cluster.api.sync.NodeSelection; +import io.lettuce.core.cluster.api.sync.NodeSelectionCommands; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import redis.embedded.Redis; + +import java.util.Map; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@RunWith(JUnitParamsRunner.class) +public class RedisClusterUtilTest { + + @Test + public void testGetMinimalHashTag() { + for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) { + assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot))); + } + } + + @SuppressWarnings("unchecked") + @Test + @Parameters(method = "argumentsForTestAssertKeyspaceNotificationsConfigured") + public void testAssertKeyspaceNotificationsConfigured(final String requiredKeyspaceNotifications, final String configuerdKeyspaceNotifications, final boolean expectException) { + final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster redisCluster = RedisClusterHelper.buildMockRedisCluster(commands); + + when(commands.configGet("notify-keyspace-events")).thenReturn(Map.of("notify-keyspace-events", configuerdKeyspaceNotifications)); + + if (expectException) { + try { + RedisClusterUtil.assertKeyspaceNotificationsConfigured(redisCluster, requiredKeyspaceNotifications); + fail("Expected IllegalStateException"); + } catch (final IllegalStateException ignored) { + } + } else { + RedisClusterUtil.assertKeyspaceNotificationsConfigured(redisCluster, requiredKeyspaceNotifications); + } + } + + @SuppressWarnings("unused") + private Object argumentsForTestAssertKeyspaceNotificationsConfigured() { + return new Object[] { + new Object[] { "K$gz", "", true }, + new Object[] { "K$gz", "K$gz", false }, + new Object[] { "K$gz", "K$gzl", false }, + new Object[] { "K$gz", "KA", false }, + new Object[] { "", "A", false }, + new Object[] { "", "", false }, + }; + } +}