diff --git a/service/pom.xml b/service/pom.xml
index 7af2b1700..9cd916669 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -172,6 +172,13 @@
test
+
+ org.signal
+ embedded-redis
+ 0.8.1
+ test
+
+
com.fasterxml.uuid
java-uuid-generator
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java
index 494288d3b..c6361ffb5 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java
@@ -1,5 +1,6 @@
package org.whispersystems.textsecuregcm.redis;
+import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
@@ -36,7 +37,8 @@ public class ClusterLuaScript {
}
}
- private ClusterLuaScript(final FaultTolerantRedisCluster redisCluster, final String script, final ScriptOutputType scriptOutputType) {
+ @VisibleForTesting
+ ClusterLuaScript(final FaultTolerantRedisCluster redisCluster, final String script, final ScriptOutputType scriptOutputType) {
this.redisCluster = redisCluster;
this.scriptOutputType = scriptOutputType;
this.script = script;
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java
new file mode 100644
index 000000000..211d82b1b
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java
@@ -0,0 +1,152 @@
+package org.whispersystems.textsecuregcm.redis;
+
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.sync.RedisCommands;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
+import redis.embedded.RedisServer;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.ServerSocket;
+import java.net.URISyntaxException;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assume.assumeFalse;
+
+/**
+ * An abstract base class that assembles a real (local!) Redis cluster and provides a client to that cluster for
+ * subclasses.
+ */
+public abstract class AbstractRedisClusterTest {
+
+ private static final int MAX_SLOT = 16384;
+ private static final int NODE_COUNT = 2;
+
+ private static RedisServer[] clusterNodes;
+
+ private FaultTolerantRedisCluster redisCluster;
+
+ @BeforeClass
+ public static void setUpBeforeClass() throws IOException, URISyntaxException, InterruptedException {
+ assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
+
+ clusterNodes = new RedisServer[NODE_COUNT];
+
+ for (int i = 0; i < NODE_COUNT; i++) {
+ clusterNodes[i] = buildClusterNode(getNextPort());
+ clusterNodes[i].start();
+ }
+
+ assembleCluster(clusterNodes);
+ }
+
+ @Before
+ public void setUp() {
+ final List urls = Arrays.stream(clusterNodes)
+ .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0)))
+ .collect(Collectors.toList());
+
+ redisCluster = new FaultTolerantRedisCluster("test-cluster", urls, Duration.ofSeconds(2), new CircuitBreakerConfiguration());
+ }
+
+ protected FaultTolerantRedisCluster getRedisCluster() {
+ return redisCluster;
+ }
+
+ @After
+ public void tearDown() {
+ redisCluster.stop();
+ }
+
+ @AfterClass
+ public static void tearDownAfterClass() {
+ for (final RedisServer node : clusterNodes) {
+ node.stop();
+ }
+ }
+
+ private static RedisServer buildClusterNode(final int port) throws IOException, URISyntaxException {
+ final File clusterConfigFile = File.createTempFile("redis", ".conf");
+ final File rdbFile = File.createTempFile("redis", ".rdb");
+
+ // Redis struggles with existing-but-empty RDB files
+ rdbFile.delete();
+ rdbFile.deleteOnExit();
+ clusterConfigFile.deleteOnExit();
+
+ return RedisServer.builder()
+ .setting("cluster-enabled yes")
+ .setting("cluster-config-file " + clusterConfigFile.getAbsolutePath())
+ .setting("cluster-node-timeout 5000")
+ .setting("dir " + System.getProperty("java.io.tmpdir"))
+ .setting("dbfilename " + rdbFile.getName())
+ .port(port)
+ .build();
+ }
+
+ private static void assembleCluster(final RedisServer... nodes) throws InterruptedException {
+ final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
+
+ try {
+ final StatefulRedisConnection connection = meetClient.connect();
+ final RedisCommands commands = connection.sync();
+
+ for (int i = 1; i < nodes.length; i++) {
+ commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0));
+ }
+ } finally {
+ meetClient.shutdown();
+ }
+
+ final int slotsPerNode = MAX_SLOT / nodes.length;
+
+ for (int i = 0; i < nodes.length; i++) {
+ final int startInclusive = i * slotsPerNode;
+ final int endExclusive = i == nodes.length - 1 ? MAX_SLOT : (i + 1) * slotsPerNode;
+
+ final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
+
+ try {
+ final int[] slots = new int[endExclusive - startInclusive];
+
+ for (int s = startInclusive; s < endExclusive; s++) {
+ slots[s - startInclusive] = s;
+ }
+
+ assignSlotClient.connect().sync().clusterAddSlots(slots);
+ } finally {
+ assignSlotClient.shutdown();
+ }
+ }
+
+ final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
+ final StatefulRedisConnection connection = waitClient.connect();
+
+ try {
+ // CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`.
+ // According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to
+ // receive queries, all slots are assigned, and a majority of master nodes are reachable.
+ while (!connection.sync().clusterInfo().contains("cluster_state:ok")) {
+ Thread.sleep(500);
+ }
+ } finally {
+ waitClient.shutdown();
+ }
+ }
+
+ private static int getNextPort() throws IOException {
+ try (ServerSocket socket = new ServerSocket(0)) {
+ socket.setReuseAddress(false);
+ return socket.getLocalPort();
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java
new file mode 100644
index 000000000..56d2a753c
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java
@@ -0,0 +1,106 @@
+package org.whispersystems.textsecuregcm.redis;
+
+import io.lettuce.core.RedisNoScriptException;
+import io.lettuce.core.ScriptOutputType;
+import io.lettuce.core.api.sync.RedisCommands;
+import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+import org.junit.Test;
+import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class ClusterLuaScriptTest extends AbstractRedisClusterTest {
+
+ @Test
+ public void testExecuteMovedKey() {
+ final String key = "key";
+ final String value = "value";
+
+ final FaultTolerantRedisCluster redisCluster = getRedisCluster();
+
+ final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", ScriptOutputType.VALUE);
+
+ assertEquals("OK", script.execute(List.of(key), List.of("value")));
+ assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key)));
+
+ final int slot = SlotHash.getSlot(key);
+
+ final int sourcePort = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).node(0).getUri().getPort());
+ final RedisCommands sourceCommands = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).commands(0));
+ final RedisCommands destinationCommands = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> !node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).commands(0));
+
+ destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId());
+
+ assertEquals("OK", script.execute(List.of(key), List.of("value")));
+ assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key)));
+
+ sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId());
+
+ assertEquals("OK", script.execute(List.of(key), List.of("value")));
+ assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key)));
+
+ for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) {
+ destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000);
+ }
+
+ assertEquals("OK", script.execute(List.of(key), List.of("value")));
+ assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key)));
+
+ destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId());
+
+ assertEquals("OK", script.execute(List.of(key), List.of("value")));
+ assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key)));
+ }
+
+ @Test
+ public void testExecute() {
+ final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class);
+ final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
+
+ final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
+ final String sha = "abc123";
+ final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
+ final List keys = List.of("key");
+ final List values = List.of("value");
+
+ when(commands.scriptLoad(script)).thenReturn(sha);
+ when(commands.evalsha(any(), any(), any(), any())).thenReturn("OK");
+
+ new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values);
+
+ verify(commands).scriptLoad(script);
+ verify(commands).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
+ }
+
+ @Test
+ public void testExecuteNoScriptException() {
+ final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class);
+ final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
+
+ final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
+ final String sha = "abc123";
+ final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
+ final List keys = List.of("key");
+ final List values = List.of("value");
+
+ when(commands.scriptLoad(script)).thenReturn(sha);
+ when(commands.evalsha(any(), any(), any(), any()))
+ .thenThrow(new RedisNoScriptException("OH NO"))
+ .thenReturn("OK");
+
+ new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values);
+
+ verify(commands, times(2)).scriptLoad(script);
+ verify(commands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
new file mode 100644
index 000000000..3b26b61aa
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
@@ -0,0 +1,44 @@
+package org.whispersystems.textsecuregcm.tests.util;
+
+import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
+import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class RedisClusterHelper {
+
+ @SuppressWarnings("unchecked")
+ public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands commands) {
+ final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class);
+ final StatefulRedisClusterConnection connection = mock(StatefulRedisClusterConnection.class);
+
+ when(connection.sync()).thenReturn(commands);
+
+ when(cluster.withReadCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(connection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(connection);
+ return null;
+ }).when(cluster).useReadCluster(any(Consumer.class));
+
+ when(cluster.withWriteCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(connection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(connection);
+ return null;
+ }).when(cluster).useWriteCluster(any(Consumer.class));
+
+ return cluster;
+ }
+}