From f5ddb0f1f82ea6d1685a888ed647aa8445b6dd2e Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 18 Jun 2020 17:12:31 -0400 Subject: [PATCH] Test ClusterLuaScript against a real Redis cluster. --- service/pom.xml | 7 + .../textsecuregcm/redis/ClusterLuaScript.java | 4 +- .../redis/AbstractRedisClusterTest.java | 152 ++++++++++++++++++ .../redis/ClusterLuaScriptTest.java | 106 ++++++++++++ .../tests/util/RedisClusterHelper.java | 44 +++++ 5 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java diff --git a/service/pom.xml b/service/pom.xml index 7af2b1700..9cd916669 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -172,6 +172,13 @@ test + + org.signal + embedded-redis + 0.8.1 + test + + com.fasterxml.uuid java-uuid-generator diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java index 494288d3b..c6361ffb5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java @@ -1,5 +1,6 @@ package org.whispersystems.textsecuregcm.redis; +import com.google.common.annotations.VisibleForTesting; import io.lettuce.core.RedisNoScriptException; import io.lettuce.core.ScriptOutputType; import io.lettuce.core.api.sync.RedisCommands; @@ -36,7 +37,8 @@ public class ClusterLuaScript { } } - private ClusterLuaScript(final FaultTolerantRedisCluster redisCluster, final String script, final ScriptOutputType scriptOutputType) { + @VisibleForTesting + ClusterLuaScript(final FaultTolerantRedisCluster redisCluster, final String script, final ScriptOutputType scriptOutputType) { this.redisCluster = redisCluster; this.scriptOutputType = scriptOutputType; this.script = script; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java new file mode 100644 index 000000000..211d82b1b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java @@ -0,0 +1,152 @@ +package org.whispersystems.textsecuregcm.redis; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import redis.embedded.RedisServer; + +import java.io.File; +import java.io.IOException; +import java.net.ServerSocket; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assume.assumeFalse; + +/** + * An abstract base class that assembles a real (local!) Redis cluster and provides a client to that cluster for + * subclasses. + */ +public abstract class AbstractRedisClusterTest { + + private static final int MAX_SLOT = 16384; + private static final int NODE_COUNT = 2; + + private static RedisServer[] clusterNodes; + + private FaultTolerantRedisCluster redisCluster; + + @BeforeClass + public static void setUpBeforeClass() throws IOException, URISyntaxException, InterruptedException { + assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows")); + + clusterNodes = new RedisServer[NODE_COUNT]; + + for (int i = 0; i < NODE_COUNT; i++) { + clusterNodes[i] = buildClusterNode(getNextPort()); + clusterNodes[i].start(); + } + + assembleCluster(clusterNodes); + } + + @Before + public void setUp() { + final List urls = Arrays.stream(clusterNodes) + .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0))) + .collect(Collectors.toList()); + + redisCluster = new FaultTolerantRedisCluster("test-cluster", urls, Duration.ofSeconds(2), new CircuitBreakerConfiguration()); + } + + protected FaultTolerantRedisCluster getRedisCluster() { + return redisCluster; + } + + @After + public void tearDown() { + redisCluster.stop(); + } + + @AfterClass + public static void tearDownAfterClass() { + for (final RedisServer node : clusterNodes) { + node.stop(); + } + } + + private static RedisServer buildClusterNode(final int port) throws IOException, URISyntaxException { + final File clusterConfigFile = File.createTempFile("redis", ".conf"); + final File rdbFile = File.createTempFile("redis", ".rdb"); + + // Redis struggles with existing-but-empty RDB files + rdbFile.delete(); + rdbFile.deleteOnExit(); + clusterConfigFile.deleteOnExit(); + + return RedisServer.builder() + .setting("cluster-enabled yes") + .setting("cluster-config-file " + clusterConfigFile.getAbsolutePath()) + .setting("cluster-node-timeout 5000") + .setting("dir " + System.getProperty("java.io.tmpdir")) + .setting("dbfilename " + rdbFile.getName()) + .port(port) + .build(); + } + + private static void assembleCluster(final RedisServer... nodes) throws InterruptedException { + final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + + try { + final StatefulRedisConnection connection = meetClient.connect(); + final RedisCommands commands = connection.sync(); + + for (int i = 1; i < nodes.length; i++) { + commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0)); + } + } finally { + meetClient.shutdown(); + } + + final int slotsPerNode = MAX_SLOT / nodes.length; + + for (int i = 0; i < nodes.length; i++) { + final int startInclusive = i * slotsPerNode; + final int endExclusive = i == nodes.length - 1 ? MAX_SLOT : (i + 1) * slotsPerNode; + + final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + + try { + final int[] slots = new int[endExclusive - startInclusive]; + + for (int s = startInclusive; s < endExclusive; s++) { + slots[s - startInclusive] = s; + } + + assignSlotClient.connect().sync().clusterAddSlots(slots); + } finally { + assignSlotClient.shutdown(); + } + } + + final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + final StatefulRedisConnection connection = waitClient.connect(); + + try { + // CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`. + // According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to + // receive queries, all slots are assigned, and a majority of master nodes are reachable. + while (!connection.sync().clusterInfo().contains("cluster_state:ok")) { + Thread.sleep(500); + } + } finally { + waitClient.shutdown(); + } + } + + private static int getNextPort() throws IOException { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(false); + return socket.getLocalPort(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java new file mode 100644 index 000000000..56d2a753c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java @@ -0,0 +1,106 @@ +package org.whispersystems.textsecuregcm.redis; + +import io.lettuce.core.RedisNoScriptException; +import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import org.junit.Test; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ClusterLuaScriptTest extends AbstractRedisClusterTest { + + @Test + public void testExecuteMovedKey() { + final String key = "key"; + final String value = "value"; + + final FaultTolerantRedisCluster redisCluster = getRedisCluster(); + + final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", ScriptOutputType.VALUE); + + assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); + + final int slot = SlotHash.getSlot(key); + + final int sourcePort = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).node(0).getUri().getPort()); + final RedisCommands sourceCommands = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).commands(0)); + final RedisCommands destinationCommands = redisCluster.withWriteCluster(connection -> connection.sync().nodes(node -> !node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.MASTER)).commands(0)); + + destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId()); + + assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); + + sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId()); + + assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); + + for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) { + destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000); + } + + assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); + + destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId()); + + assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); + } + + @Test + public void testExecute() { + final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final String sha = "abc123"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key"); + final List values = List.of("value"); + + when(commands.scriptLoad(script)).thenReturn(sha); + when(commands.evalsha(any(), any(), any(), any())).thenReturn("OK"); + + new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values); + + verify(commands).scriptLoad(script); + verify(commands).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + } + + @Test + public void testExecuteNoScriptException() { + final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final String sha = "abc123"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key"); + final List values = List.of("value"); + + when(commands.scriptLoad(script)).thenReturn(sha); + when(commands.evalsha(any(), any(), any(), any())) + .thenThrow(new RedisNoScriptException("OH NO")) + .thenReturn("OK"); + + new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values); + + verify(commands, times(2)).scriptLoad(script); + verify(commands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java new file mode 100644 index 000000000..3b26b61aa --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java @@ -0,0 +1,44 @@ +package org.whispersystems.textsecuregcm.tests.util; + +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RedisClusterHelper { + + @SuppressWarnings("unchecked") + public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands commands) { + final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class); + final StatefulRedisClusterConnection connection = mock(StatefulRedisClusterConnection.class); + + when(connection.sync()).thenReturn(commands); + + when(cluster.withReadCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(connection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(connection); + return null; + }).when(cluster).useReadCluster(any(Consumer.class)); + + when(cluster.withWriteCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(connection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(connection); + return null; + }).when(cluster).useWriteCluster(any(Consumer.class)); + + return cluster; + } +}