diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java index 22d4f4ffc..e5a992048 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java @@ -8,6 +8,7 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.failedFuture; +import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Clock; @@ -88,12 +89,12 @@ public class StaticRateLimiter implements RateLimiter { @Override public void clear(final String key) { - cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key))); + cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key))); } @Override public CompletionStage clearAsync(final String key) { - return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key))) + return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key))) .thenRun(Util.NOOP); } @@ -103,7 +104,7 @@ public class StaticRateLimiter implements RateLimiter { } private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { - final List keys = List.of(bucketName(key)); + final List keys = List.of(bucketName(name, key)); final List arguments = List.of( String.valueOf(config.bucketSize()), String.valueOf(config.leakRatePerMillis()), @@ -115,7 +116,7 @@ public class StaticRateLimiter implements RateLimiter { } private CompletionStage executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) { - final List keys = List.of(bucketName(key)); + final List keys = List.of(bucketName(name, key)); final List arguments = List.of( String.valueOf(config.bucketSize()), String.valueOf(config.leakRatePerMillis()), @@ -126,7 +127,8 @@ public class StaticRateLimiter implements RateLimiter { return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o); } - private String bucketName(final String key) { + @VisibleForTesting + protected static String bucketName(final String name, final String key) { return "leaky_bucket::" + name + "::" + key; } } diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua index 5c092320a..424be261a 100644 --- a/service/src/main/resources/lua/validate_rate_limit.lua +++ b/service/src/main/resources/lua/validate_rate_limit.lua @@ -14,48 +14,72 @@ local currentTimeMillis = tonumber(ARGV[3]) local requestedAmount = tonumber(ARGV[4]) local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true" -local tokenBucketJson = redis.call("GET", bucketId) -local tokenBucket +local SIZE_FIELD = "s" +local TIME_FIELD = "t" + local changesMade = false +local tokensRemaining +local lastUpdateTimeMillis -if tokenBucketJson then - tokenBucket = cjson.decode(tokenBucketJson) -else - tokenBucket = { - ["bucketSize"] = bucketSize, - ["leakRatePerMillis"] = refillRatePerMillis, - ["spaceRemaining"] = bucketSize, - ["lastUpdateTimeMillis"] = currentTimeMillis - } -end - --- this can happen if rate limiter configuration has changed while the key is still in Redis -if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then - tokenBucket["bucketSize"] = bucketSize - tokenBucket["leakRatePerMillis"] = refillRatePerMillis +-- while we're migrating from json to redis list key types, there are three possible options for the +-- type of the `bucketId` key: "string" (legacy, json value), "list" (new format), "none" (key not set). +-- +-- In the phase 1 of migration, we prepare the script to deal with the phase 2 :) I.e. when phase 2 will be rolling out, +-- it will start writing data in the new format, and the still running instances of the previous version +-- need to be able to know how to read the new format before we start writing it. +-- +-- On a separate note -- the reason we're not using a different key is because Redis Lua requires to list all keys +-- as a script input and we don't want to expose this migration to the script users. +-- +-- Finally, it's okay to read the "ok" key of the return here because "TYPE" command always succeeds. +local keyType = redis.call("TYPE", bucketId)["ok"] +if keyType == "none" then + -- if the key is not set, building the object from the configuration + tokensRemaining = bucketSize + lastUpdateTimeMillis = currentTimeMillis +elseif keyType == "string" then + -- if the key is "string", we parse the value from json + local fromJson = cjson.decode(redis.call("GET", bucketId)) + if bucketSize ~= fromJson.bucketSize or refillRatePerMillis ~= fromJson.leakRatePerMillis then + changesMade = true + end + tokensRemaining = fromJson.spaceRemaining + lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis +elseif keyType == "hash" then + -- finally, reading values from the new storage format + local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD)) + tokensRemaining = tonumber(tokensRemainingStr) + lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr) + redis.call("DEL", bucketId) changesMade = true end -local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"] +local elapsedTime = currentTimeMillis - lastUpdateTimeMillis local availableAmount = math.min( - tokenBucket["bucketSize"], - math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"])) + bucketSize, + math.floor(tokensRemaining + (elapsedTime * refillRatePerMillis)) ) if availableAmount >= requestedAmount then if useTokens then - tokenBucket["spaceRemaining"] = availableAmount - requestedAmount - tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis + tokensRemaining = availableAmount - requestedAmount + lastUpdateTimeMillis = currentTimeMillis changesMade = true end if changesMade then - local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"] + local tokensUsed = bucketSize - tokensRemaining -- Storing a 'full' bucket is equivalent of not storing any state at all -- (in which case a bucket will be just initialized from the input configs as a 'full' one). -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish) -- or we just delete the key if the bucket is full. if tokensUsed > 0 then - local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"]) + local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis) + local tokenBucket = { + ["bucketSize"] = bucketSize, + ["leakRatePerMillis"] = refillRatePerMillis, + ["spaceRemaining"] = tokensRemaining, + ["lastUpdateTimeMillis"] = lastUpdateTimeMillis + } redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis) else redis.call("DEL", bucketId) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java index 424c0fff1..eb69ba30d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -19,6 +19,7 @@ import java.util.Optional; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; @@ -60,7 +61,51 @@ public class RateLimitersLuaScriptTest { final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); rateLimiter.validate("test", 25); rateLimiter.validate("test", 25); - assertThrows(Exception.class, () -> rateLimiter.validate("test", 25)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 25)); + } + + @Test + public void testFormatMigration() throws Exception { + final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); + final RateLimiters limiters = new RateLimiters( + Map.of(descriptor.id(), new RateLimiterConfig(60, 60)), + dynamicConfig, + RateLimiters.defaultScript(redisCluster), + redisCluster, + Clock.systemUTC()); + + final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); + + // embedding an existing value in the old format + redisCluster.useCluster(c -> c.sync().set( + StaticRateLimiter.bucketName(descriptor.id(), "test"), + serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) + )); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40)); + + // embedding an existing value in the old format + redisCluster.useCluster(c -> c.sync().set( + StaticRateLimiter.bucketName(descriptor.id(), "test1"), + serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) + )); + rateLimiter.validate("test1", 20); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20)); + + // embedding an existing value in the new format + redisCluster.useCluster(c -> c.sync().hset( + StaticRateLimiter.bucketName(descriptor.id(), "test2"), + Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000)) + )); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test2", 40)); + + // embedding an existing value in the new format + redisCluster.useCluster(c -> c.sync().hset( + StaticRateLimiter.bucketName(descriptor.id(), "test3"), + Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000)) + )); + rateLimiter.validate("test3", 20); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test3", 20)); } @Test @@ -117,6 +162,19 @@ public class RateLimitersLuaScriptTest { assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); } + private String serializeToOldBucketValueFormat(final TokenBucket bucket) { + try { + return SystemMapper.jsonMapper().writeValueAsString(Map.of( + "bucketSize", bucket.bucketSize, + "leakRatePerMillis", bucket.leakRatePerMillis, + "spaceRemaining", bucket.spaceRemaining, + "lastUpdateTimeMillis", bucket.lastUpdateTimeMillis + )); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + private Optional decodeBucket(final String key) { try { final String json = redisCommandsHandler.get(key); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java index 0619ad8c9..f5617c1a1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail; import java.util.List; +import java.util.Map; /** * This class is to be extended with implementations of Redis commands as needed. @@ -28,23 +29,79 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler { yield get(args.get(0).toString()); } case "DEL" -> { - assertTrue(args.size() > 1); - yield del(args.get(0).toString()); + assertTrue(args.size() >= 1); + yield del(args.stream().map(Object::toString).toList()); } + case "HSET" -> { + assertTrue(args.size() >= 3); + yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3)); + } + case "HGET" -> { + assertEquals(2, args.size()); + yield hget(args.get(0).toString(), args.get(1).toString()); + } + case "PEXPIRE" -> { + assertEquals(2, args.size()); + yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2)); + } + case "TYPE" -> { + assertEquals(1, args.size()); + yield type(args.get(0).toString()); + } + case "RPUSH" -> { + assertTrue(args.size() > 1); + yield push(false, args.get(0).toString(), tail(args, 1)); + } + case "LPUSH" -> { + assertTrue(args.size() > 1); + yield push(true, args.get(0).toString(), tail(args, 1)); + } + case "RPOP" -> { + assertEquals(2, args.size()); + yield pop(false, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue()); + } + case "LPOP" -> { + assertEquals(2, args.size()); + yield pop(true, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue()); + } + default -> other(command, args); }; } + public Object[] pop(final boolean left, final String key, final int count) { + return new Object[count]; + } + + public Object push(final boolean left, final String key, final List values) { + return 0; + } + + public Object type(final String key) { + return Map.of("ok", "none"); + } + + public Object pexpire(final String key, final long ttlMillis, final List args) { + return 0; + } + + public Object hset(final String key, final String field, final String value, final List other) { + return "OK"; + } + + public Object hget(final String key, final String field) { + return null; + } + public Object set(final String key, final String value, final List tail) { return "OK"; } public String get(final String key) { return null; - } - public int del(final String key) { + public int del(final List keys) { return 0; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java index 06452e5bb..734617ed5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java @@ -28,7 +28,12 @@ public class RedisLuaScriptSandbox { function redis_call(...) -- variable name needs to match the one used in the `L.setGlobal()` call -- method name needs to match method name of the Java class - return proxy:redisCall(arg) + local result = proxy:redisCall(arg) + if type(result) == "userdata" then + return java.luaify(result) + else + return result + end end function json_encode(obj) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java index 290da8e48..75d48c07e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java @@ -6,13 +6,15 @@ package org.whispersystems.textsecuregcm.util.redis; import java.time.Clock; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import javax.annotation.Nullable; public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { - public record Entry(String value, long expirationEpochMillis) { + public record Entry(Object value, long expirationEpochMillis) { } private final Map cache = new ConcurrentHashMap<>(); @@ -32,20 +34,106 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { @Override public String get(final String key) { + return getIfNotExpired(key, String.class); + } + + @Override + public int del(final List key) { + return key.stream() + .mapToInt(k -> cache.remove(k) != null ? 1 : 0) + .sum(); + } + + @SuppressWarnings("unchecked") + @Override + public Object hset(final String key, final String field, final String value, final List other) { + Map map = getIfNotExpired(key, Map.class); + if (map == null) { + map = new ConcurrentHashMap<>(); + cache.put(key, new Entry(map, Long.MAX_VALUE)); + } + map.put(field, value); + return "OK"; + } + + @Override + public Object hget(final String key, final String field) { + final Map map = getIfNotExpired(key, Map.class); + return map == null ? null : map.get(field); + } + + @SuppressWarnings("unchecked") + @Override + public Object push(final boolean left, final String key, final List values) { + LinkedList list = getIfNotExpired(key, LinkedList.class); + if (list == null) { + list = new LinkedList<>(); + cache.put(key, new Entry(list, Long.MAX_VALUE)); + } + for (Object v: values) { + if (left) { + list.addFirst(v.toString()); + } else { + list.addLast(v.toString()); + } + } + return list.size(); + } + + @SuppressWarnings("unchecked") + @Override + public Object[] pop(final boolean left, final String key, final int count) { + final Object[] result = new String[count]; + final LinkedList list = getIfNotExpired(key, LinkedList.class); + if (list == null) { + return result; + } + for (int i = 0; i < Math.min(count, list.size()); i++) { + result[i] = left ? list.removeFirst() : list.removeLast(); + } + return result; + } + + @Override + public Object pexpire(final String key, final long ttlMillis, final List args) { + final Entry e = cache.get(key); + if (e == null) { + return 0; + } + final Entry updated = new Entry(e.value(), clock.millis() + ttlMillis); + cache.put(key, updated); + return 1; + } + + @Override + public Object type(final String key) { + final Object o = getIfNotExpired(key, Object.class); + final String type; + if (o == null) { + type = "none"; + } else if (o.getClass() == String.class) { + type = "string"; + } else if (Map.class.isAssignableFrom(o.getClass())) { + type = "hash"; + } else if (List.class.isAssignableFrom(o.getClass())) { + type = "list"; + } else { + throw new IllegalArgumentException("Unsupported value type: " + o.getClass()); + } + return Map.of("ok", type); + } + + @Nullable + protected T getIfNotExpired(final String key, final Class expectedType) { final Entry entry = cache.get(key); if (entry == null) { return null; } if (entry.expirationEpochMillis() < clock.millis()) { - del(key); + del(List.of(key)); return null; } - return entry.value(); - } - - @Override - public int del(final String key) { - return cache.remove(key) != null ? 1 : 0; + return expectedType.cast(entry.value()); } protected long resolveExpirationEpochMillis(final List args) {