From aa4bd92fee0f3ac8708557ca6159a08e9ab560a8 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Tue, 28 Sep 2021 12:21:56 -0400 Subject: [PATCH] Lazy-load scripts; fall back to `eval` if `evalsha` returns `NOSCRIPT` --- .../textsecuregcm/redis/ClusterLuaScript.java | 74 ++++---- .../redis/ClusterLuaScriptTest.java | 173 ++++++++---------- 2 files changed, 108 insertions(+), 139 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java index 1369047e9..089d3f37a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java @@ -8,15 +8,17 @@ package org.whispersystems.textsecuregcm.redis; import com.google.common.annotations.VisibleForTesting; import io.lettuce.core.RedisNoScriptException; import io.lettuce.core.ScriptOutputType; -import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.ByteArrayOutputStream; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.codec.binary.Hex; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ClusterLuaScript { @@ -53,48 +55,40 @@ public class ClusterLuaScript { this.redisCluster = redisCluster; this.scriptOutputType = scriptOutputType; this.script = script; - this.sha = redisCluster.withCluster(connection -> connection.sync().scriptLoad(script)); + + try { + this.sha = Hex.encodeHexString(MessageDigest.getInstance("SHA-1").digest(script.getBytes(StandardCharsets.UTF_8))); + } catch (final NoSuchAlgorithmException e) { + // All Java implementations are required to support SHA-1, so this should never happen + throw new AssertionError(e); + } + } + + @VisibleForTesting + String getSha() { + return sha; } public Object execute(final List keys, final List args) { - return redisCluster.withCluster(connection -> { - try { - final RedisAdvancedClusterCommands clusterCommands = connection.sync(); - - try { - return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)); - } catch (final RedisNoScriptException e) { - reloadScript(); - return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)); - } - } catch (final Exception e) { - log.warn("Failed to execute script", e); - throw e; - } - }); + return redisCluster.withCluster(connection -> + execute(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY))); } public Object executeBinary(final List keys, final List args) { - return redisCluster.withBinaryCluster(connection -> { - try { - final RedisAdvancedClusterCommands binaryCommands = connection.sync(); - - try { - return binaryCommands - .evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); - } catch (final RedisNoScriptException e) { - reloadScript(); - return binaryCommands - .evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); - } - } catch (final Exception e) { - log.warn("Failed to execute script", e); - throw e; - } - }); + return redisCluster.withBinaryCluster(connection -> + execute(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY))); } - private void reloadScript() { - redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptLoad(script)); + private Object execute(final StatefulRedisClusterConnection connection, final T[] keys, final T[] args) { + try { + try { + return connection.sync().evalsha(sha, scriptOutputType, keys, args); + } catch (final RedisNoScriptException e) { + return connection.sync().eval(script, scriptOutputType, keys, args); + } + } catch (final Exception e) { + log.warn("Failed to execute script", e); + throw e; + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java index 3553a998e..feccf3bb7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java @@ -5,146 +5,121 @@ package org.whispersystems.textsecuregcm.redis; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import io.lettuce.core.RedisNoScriptException; import io.lettuce.core.ScriptOutputType; -import io.lettuce.core.api.sync.RedisCommands; -import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; -import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; import java.util.List; -import org.junit.Test; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; -public class ClusterLuaScriptTest extends AbstractRedisClusterTest { +public class ClusterLuaScriptTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @Test - public void testExecuteMovedKey() { - final String key = "key"; - final String value = "value"; - - final FaultTolerantRedisCluster redisCluster = getRedisCluster(); - - final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", - ScriptOutputType.VALUE); - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - - final int slot = SlotHash.getSlot(key); - - final int sourcePort = redisCluster.withCluster( - connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM)) - .node(0).getUri().getPort()); - final RedisCommands sourceCommands = redisCluster.withCluster( - connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM)) - .commands(0)); - final RedisCommands destinationCommands = redisCluster.withCluster(connection -> connection.sync() - .nodes(node -> !node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM)).commands(0)); - - destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId()); - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - - sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId()); - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - - for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) { - destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000); - } - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - - destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId()); - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - } - - @Test - public void testExecute() { + void testExecute() { final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; - final String sha = "abc123"; final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; final List keys = List.of("key"); final List values = List.of("value"); - when(commands.scriptLoad(script)).thenReturn(sha); when(commands.evalsha(any(), any(), any(), any())).thenReturn("OK"); - new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values); + final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType); + luaScript.execute(keys, values); - verify(commands).scriptLoad(script); - verify(commands).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + verify(commands, never()).eval(anyString(), any(), any(), any()); } @Test - public void testExecuteNoScriptException() { - final String key = "key"; - final String value = "value"; - - final FaultTolerantRedisCluster redisCluster = getRedisCluster(); - - final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", - ScriptOutputType.VALUE); - - // Remove the scripts created by the CLusterLuaScript constructor - redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush()); - - assertEquals("OK", script.execute(List.of(key), List.of(value))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); - } - - @Test - public void testExecuteBinary() { - final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); - final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = RedisClusterHelper - .buildMockRedisCluster(stringCommands, binaryCommands); + void testExecuteScriptNotLoaded() { + final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key"); + final List values = List.of("value"); + + when(commands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO")); + + final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType); + luaScript.execute(keys, values); + + verify(commands).eval(script, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + } + + @Test + void testExecuteBinaryScriptNotLoaded() { + final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); + final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = + RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands); final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; - final String sha = "abc123"; final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); - when(stringCommands.scriptLoad(script)).thenReturn(sha); - when(binaryCommands.evalsha(any(), any(), any(), any())).thenReturn("OK".getBytes(StandardCharsets.UTF_8)); + when(binaryCommands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO")); - new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values); + final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType); + luaScript.executeBinary(keys, values); - verify(stringCommands).scriptLoad(script); - verify(binaryCommands).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + verify(binaryCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); } @Test - public void testExecuteBinaryNoScriptException() { - final String key = "key"; - final String value = "value"; + public void testExecuteRealCluster() { + final ClusterLuaScript script = new ClusterLuaScript(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + "return 2;", + ScriptOutputType.INTEGER); - final FaultTolerantRedisCluster redisCluster = getRedisCluster(); + for (int i = 0; i < 7; i++) { + assertEquals(2L, script.execute(Collections.emptyList(), Collections.emptyList())); + } - final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", - ScriptOutputType.VALUE); + final int evalCount = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> { + final String commandStats = connection.sync().info("commandstats"); - // Remove the scripts created by the CLusterLuaScript constructor - redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush()); + // We're looking for (and parsing) a line in the command stats that looks like: + // + // ``` + // cmdstat_eval:calls=1,usec=44,usec_per_call=44.00 + // ``` + return Arrays.stream(commandStats.split("\\n")) + .filter(line -> line.startsWith("cmdstat_eval:")) + .map(String::trim) + .map(evalLine -> Arrays.stream(evalLine.substring(evalLine.indexOf(':') + 1).split(",")) + .filter(pair -> pair.startsWith("calls=")) + .map(callsPair -> Integer.parseInt(callsPair.substring(callsPair.indexOf('=') + 1))) + .findFirst() + .orElse(0)) + .findFirst() + .orElse(0); + }); - assertArrayEquals("OK".getBytes(StandardCharsets.UTF_8), (byte[]) script - .executeBinary(List.of(key.getBytes(StandardCharsets.UTF_8)), List.of(value.getBytes(StandardCharsets.UTF_8)))); - assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); + assertEquals(1, evalCount); } }