diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java index b6b9a423a..ab7e62772 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java @@ -64,7 +64,7 @@ public class ClusterLuaScript { try { return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)); } catch (final RedisNoScriptException e) { - clusterCommands.scriptLoad(script); + reloadScript(); return clusterCommands.evalsha(sha, scriptOutputType, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)); } } catch (final Exception e) { @@ -82,7 +82,7 @@ public class ClusterLuaScript { try { return binaryCommands.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); } catch (final RedisNoScriptException e) { - binaryCommands.scriptLoad(script.getBytes(StandardCharsets.UTF_8)); + reloadScript(); return binaryCommands.evalsha(sha, scriptOutputType, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)); } } catch (final Exception e) { @@ -91,4 +91,8 @@ public class ClusterLuaScript { } }); } + + private void reloadScript() { + redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptLoad(script)); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java index b2f97b152..b627c8226 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScriptTest.java @@ -90,24 +90,18 @@ public class ClusterLuaScriptTest extends AbstractRedisClusterTest { @Test public void testExecuteNoScriptException() { - final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + final String key = "key"; + final String value = "value"; - final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; - final String sha = "abc123"; - final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; - final List keys = List.of("key"); - final List values = List.of("value"); + final FaultTolerantRedisCluster redisCluster = getRedisCluster(); - when(commands.scriptLoad(script)).thenReturn(sha); - when(commands.evalsha(any(), any(), any(), any())) - .thenThrow(new RedisNoScriptException("OH NO")) - .thenReturn("OK"); + final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", ScriptOutputType.VALUE); - new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values); + // Remove the scripts created by the CLusterLuaScript constructor + redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush()); - verify(commands, times(2)).scriptLoad(script); - verify(commands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0])); + assertEquals("OK", script.execute(List.of(key), List.of(value))); + assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); } @Test @@ -133,25 +127,17 @@ public class ClusterLuaScriptTest extends AbstractRedisClusterTest { @Test public void testExecuteBinaryNoScriptException() { - final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); - final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands); + final String key = "key"; + final String value = "value"; - final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; - final String sha = "abc123"; - final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; - final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); - final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); + final FaultTolerantRedisCluster redisCluster = getRedisCluster(); - when(stringCommands.scriptLoad(script)).thenReturn(sha); - when(binaryCommands.evalsha(any(), any(), any(), any())) - .thenThrow(new RedisNoScriptException("OH NO")) - .thenReturn("OK".getBytes(StandardCharsets.UTF_8)); + final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])", ScriptOutputType.VALUE); - new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values); + // Remove the scripts created by the CLusterLuaScript constructor + redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush()); - verify(stringCommands).scriptLoad(script); - verify(binaryCommands).scriptLoad(script.getBytes(StandardCharsets.UTF_8)); - verify(binaryCommands, times(2)).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + assertArrayEquals("OK".getBytes(StandardCharsets.UTF_8), (byte[])script.executeBinary(List.of(key.getBytes(StandardCharsets.UTF_8)), List.of(value.getBytes(StandardCharsets.UTF_8)))); + assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key))); } }