Lazy-load scripts; fall back to `eval` if `evalsha` returns `NOSCRIPT`

This commit is contained in:
Jon Chambers 2021-09-28 12:21:56 -04:00 committed by Jon Chambers
parent f37c76dab1
commit aa4bd92fee
2 changed files with 108 additions and 139 deletions

View File

@ -8,15 +8,17 @@ package org.whispersystems.textsecuregcm.redis;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.codec.binary.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ClusterLuaScript {
@ -53,48 +55,40 @@ public class ClusterLuaScript {
this.redisCluster = redisCluster;
this.scriptOutputType = scriptOutputType;
this.script = script;
this.sha = redisCluster.withCluster(connection -> connection.sync().scriptLoad(script));
try {
this.sha = Hex.encodeHexString(MessageDigest.getInstance("SHA-1").digest(script.getBytes(StandardCharsets.UTF_8)));
} catch (final NoSuchAlgorithmException e) {
// All Java implementations are required to support SHA-1, so this should never happen
throw new AssertionError(e);
}
}
@VisibleForTesting
String getSha() {
return sha;
}
public Object execute(final List<String> keys, final List<String> args) {
return redisCluster.withCluster(connection -> {
try {
final RedisAdvancedClusterCommands<String, String> clusterCommands = connection.sync();
try {
return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY));
} catch (final RedisNoScriptException e) {
reloadScript();
return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY));
}
} catch (final Exception e) {
log.warn("Failed to execute script", e);
throw e;
}
});
return redisCluster.withCluster(connection ->
execute(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)));
}
public Object executeBinary(final List<byte[]> keys, final List<byte[]> args) {
return redisCluster.withBinaryCluster(connection -> {
try {
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = connection.sync();
try {
return binaryCommands
.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY));
} catch (final RedisNoScriptException e) {
reloadScript();
return binaryCommands
.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY));
}
} catch (final Exception e) {
log.warn("Failed to execute script", e);
throw e;
}
});
return redisCluster.withBinaryCluster(connection ->
execute(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)));
}
private void reloadScript() {
redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptLoad(script));
private <T> Object execute(final StatefulRedisClusterConnection<T, T> connection, final T[] keys, final T[] args) {
try {
try {
return connection.sync().evalsha(sha, scriptOutputType, keys, args);
} catch (final RedisNoScriptException e) {
return connection.sync().eval(script, scriptOutputType, keys, args);
}
} catch (final Exception e) {
log.warn("Failed to execute script", e);
throw e;
}
}
}

View File

@ -5,146 +5,121 @@
package org.whispersystems.textsecuregcm.redis;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Test;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
public class ClusterLuaScriptTest extends AbstractRedisClusterTest {
public class ClusterLuaScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
public void testExecuteMovedKey() {
final String key = "key";
final String value = "value";
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
final int slot = SlotHash.getSlot(key);
final int sourcePort = redisCluster.withCluster(
connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM))
.node(0).getUri().getPort());
final RedisCommands<String, String> sourceCommands = redisCluster.withCluster(
connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM))
.commands(0));
final RedisCommands<String, String> destinationCommands = redisCluster.withCluster(connection -> connection.sync()
.nodes(node -> !node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM)).commands(0));
destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) {
destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000);
}
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
}
@Test
public void testExecute() {
void testExecute() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final String sha = "abc123";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<String> keys = List.of("key");
final List<String> values = List.of("value");
when(commands.scriptLoad(script)).thenReturn(sha);
when(commands.evalsha(any(), any(), any(), any())).thenReturn("OK");
new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values);
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.execute(keys, values);
verify(commands).scriptLoad(script);
verify(commands).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands, never()).eval(anyString(), any(), any(), any());
}
@Test
public void testExecuteNoScriptException() {
final String key = "key";
final String value = "value";
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
// Remove the scripts created by the CLusterLuaScript constructor
redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
}
@Test
public void testExecuteBinary() {
final RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper
.buildMockRedisCluster(stringCommands, binaryCommands);
void testExecuteScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<String> keys = List.of("key");
final List<String> values = List.of("value");
when(commands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO"));
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.execute(keys, values);
verify(commands).eval(script, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
}
@Test
void testExecuteBinaryScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster =
RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final String sha = "abc123";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<byte[]> keys = List.of("key".getBytes(StandardCharsets.UTF_8));
final List<byte[]> values = List.of("value".getBytes(StandardCharsets.UTF_8));
when(stringCommands.scriptLoad(script)).thenReturn(sha);
when(binaryCommands.evalsha(any(), any(), any(), any())).thenReturn("OK".getBytes(StandardCharsets.UTF_8));
when(binaryCommands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO"));
new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values);
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.executeBinary(keys, values);
verify(stringCommands).scriptLoad(script);
verify(binaryCommands).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
}
@Test
public void testExecuteBinaryNoScriptException() {
final String key = "key";
final String value = "value";
public void testExecuteRealCluster() {
final ClusterLuaScript script = new ClusterLuaScript(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"return 2;",
ScriptOutputType.INTEGER);
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
for (int i = 0; i < 7; i++) {
assertEquals(2L, script.execute(Collections.emptyList(), Collections.emptyList()));
}
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
final int evalCount = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> {
final String commandStats = connection.sync().info("commandstats");
// Remove the scripts created by the CLusterLuaScript constructor
redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush());
// We're looking for (and parsing) a line in the command stats that looks like:
//
// ```
// cmdstat_eval:calls=1,usec=44,usec_per_call=44.00
// ```
return Arrays.stream(commandStats.split("\\n"))
.filter(line -> line.startsWith("cmdstat_eval:"))
.map(String::trim)
.map(evalLine -> Arrays.stream(evalLine.substring(evalLine.indexOf(':') + 1).split(","))
.filter(pair -> pair.startsWith("calls="))
.map(callsPair -> Integer.parseInt(callsPair.substring(callsPair.indexOf('=') + 1)))
.findFirst()
.orElse(0))
.findFirst()
.orElse(0);
});
assertArrayEquals("OK".getBytes(StandardCharsets.UTF_8), (byte[]) script
.executeBinary(List.of(key.getBytes(StandardCharsets.UTF_8)), List.of(value.getBytes(StandardCharsets.UTF_8))));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
assertEquals(1, evalCount);
}
}