diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java index 59a8a2ff5..0b6a24e29 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java @@ -13,6 +13,7 @@ import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.List; public class ClusterLuaScript { @@ -22,7 +23,8 @@ public class ClusterLuaScript { private final String script; private final String sha; - private static final String[] STRING_ARRAY = new String[0]; + private static final String[] STRING_ARRAY = new String[0]; + private static final byte[][] BYTE_ARRAY_ARRAY = new byte[0][]; private static final Logger log = LoggerFactory.getLogger(ClusterLuaScript.class); @@ -66,4 +68,22 @@ public class ClusterLuaScript { } }); } + + public Object executeBinary(final List keys, final List args) { + return redisCluster.withBinaryWriteCluster(connection -> { + try { + final RedisAdvancedClusterCommands binaryCommands = connection.sync(); + + try { + return binaryCommands.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); + } catch (final RedisNoScriptException e) { + binaryCommands.scriptLoad(script.getBytes(StandardCharsets.UTF_8)); + return binaryCommands.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); + } + } catch (final Exception e) { + log.warn("Failed to execute script", e); + throw e; + } + }); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java index 56d2a753c..20f84c958 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java @@ -9,8 +9,10 @@ import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import org.junit.Test; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import java.nio.charset.StandardCharsets; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -30,7 +32,7 @@ public class ClusterLuaScriptTest extends AbstractRedisClusterTest { final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", ScriptOutputType.VALUE); - assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals("OK", script.execute(List.of(key), List.of(value))); assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); final int slot = SlotHash.getSlot(key); @@ -41,24 +43,24 @@ public class ClusterLuaScriptTest extends AbstractRedisClusterTest { destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId()); - assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals("OK", script.execute(List.of(key), List.of(value))); assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId()); - assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals("OK", script.execute(List.of(key), List.of(value))); assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) { destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000); } - assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals("OK", script.execute(List.of(key), List.of(value))); assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId()); - assertEquals("OK", script.execute(List.of(key), List.of("value"))); + assertEquals("OK", script.execute(List.of(key), List.of(value))); assertEquals(value, redisCluster.withReadCluster(connection -> connection.sync().get(key))); } @@ -103,4 +105,49 @@ public class ClusterLuaScriptTest extends AbstractRedisClusterTest { verify(commands, times(2)).scriptLoad(script); verify(commands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); } + + @Test + public void testExecuteBinary() { + final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); + final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final String sha = "abc123"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); + final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); + + when(stringCommands.scriptLoad(script)).thenReturn(sha); + when(binaryCommands.evalsha(any(), any(), any(), any())).thenReturn("OK".getBytes(StandardCharsets.UTF_8)); + + new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values); + + verify(stringCommands).scriptLoad(script); + verify(binaryCommands).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + } + + @Test + public void testExecuteBinaryNoScriptException() { + final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); + final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final String sha = "abc123"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); + final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); + + when(stringCommands.scriptLoad(script)).thenReturn(sha); + when(binaryCommands.evalsha(any(), any(), any(), any())) + .thenThrow(new RedisNoScriptException("OH NO")) + .thenReturn("OK".getBytes(StandardCharsets.UTF_8)); + + new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values); + + verify(stringCommands).scriptLoad(script); + verify(binaryCommands).scriptLoad(script.getBytes(StandardCharsets.UTF_8)); + verify(binaryCommands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + } }