From c14ef7e6cff437680814ab626f7fd21fba1a355f Mon Sep 17 00:00:00 2001 From: Sergey Skrobotov Date: Thu, 16 Mar 2023 01:41:48 -0700 Subject: [PATCH] migrate token bucket redis record format from json to hash: phase 2 --- .../resources/lua/validate_rate_limit.lua | 22 ++---- .../limits/RateLimitersLuaScriptTest.java | 72 +++++++++---------- .../util/redis/BaseRedisCommandsHandler.java | 18 +++-- .../redis/SimpleCacheCommandsHandler.java | 19 ++++- 4 files changed, 71 insertions(+), 60 deletions(-) diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua index 424be261a..d18c4e52c 100644 --- a/service/src/main/resources/lua/validate_rate_limit.lua +++ b/service/src/main/resources/lua/validate_rate_limit.lua @@ -24,10 +24,6 @@ local lastUpdateTimeMillis -- while we're migrating from json to redis list key types, there are three possible options for the -- type of the `bucketId` key: "string" (legacy, json value), "list" (new format), "none" (key not set). -- --- In the phase 1 of migration, we prepare the script to deal with the phase 2 :) I.e. when phase 2 will be rolling out, --- it will start writing data in the new format, and the still running instances of the previous version --- need to be able to know how to read the new format before we start writing it. --- -- On a separate note -- the reason we're not using a different key is because Redis Lua requires to list all keys -- as a script input and we don't want to expose this migration to the script users. -- @@ -40,18 +36,15 @@ if keyType == "none" then elseif keyType == "string" then -- if the key is "string", we parse the value from json local fromJson = cjson.decode(redis.call("GET", bucketId)) - if bucketSize ~= fromJson.bucketSize or refillRatePerMillis ~= fromJson.leakRatePerMillis then - changesMade = true - end tokensRemaining = fromJson.spaceRemaining lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis + redis.call("DEL", bucketId) + changesMade = true elseif keyType == "hash" then -- finally, reading values from the new storage format local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD)) tokensRemaining = tonumber(tokensRemainingStr) lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr) - redis.call("DEL", bucketId) - changesMade = true end local elapsedTime = currentTimeMillis - lastUpdateTimeMillis @@ -68,19 +61,14 @@ if availableAmount >= requestedAmount then end if changesMade then local tokensUsed = bucketSize - tokensRemaining - -- Storing a 'full' bucket is equivalent of not storing any state at all + -- Storing a 'full' bucket (i.e. tokensUsed == 0) is equivalent of not storing any state at all -- (in which case a bucket will be just initialized from the input configs as a 'full' one). -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish) -- or we just delete the key if the bucket is full. if tokensUsed > 0 then local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis) - local tokenBucket = { - ["bucketSize"] = bucketSize, - ["leakRatePerMillis"] = refillRatePerMillis, - ["spaceRemaining"] = tokensRemaining, - ["lastUpdateTimeMillis"] = lastUpdateTimeMillis - } - redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis) + redis.call("HSET", bucketId, SIZE_FIELD, tokensRemaining, TIME_FIELD, lastUpdateTimeMillis) + redis.call("PEXPIRE", bucketId, ttlMillis) else redis.call("DEL", bucketId) end diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java index eb69ba30d..30ec78ea7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.limits; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -80,14 +81,14 @@ public class RateLimitersLuaScriptTest { // embedding an existing value in the old format redisCluster.useCluster(c -> c.sync().set( StaticRateLimiter.bucketName(descriptor.id(), "test"), - serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) + serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000) )); assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40)); // embedding an existing value in the old format redisCluster.useCluster(c -> c.sync().set( StaticRateLimiter.bucketName(descriptor.id(), "test1"), - serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) + serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000) )); rateLimiter.validate("test1", 20); assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20)); @@ -109,25 +110,21 @@ public class RateLimitersLuaScriptTest { } @Test - public void testLuaBucketConfigurationUpdates() throws Exception { - final String key = "key1"; - clock.setTimeMillis(0); - long result = (long) sandbox.execute( - List.of(key), - scriptArgs(1000, 1, 1, true), - redisCommandsHandler - ); - assertEquals(0L, result); - assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize); + public void testTtl() throws Exception { + final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); + final RateLimiters limiters = new RateLimiters( + Map.of(descriptor.id(), new RateLimiterConfig(1000, 60)), + dynamicConfig, + RateLimiters.defaultScript(redisCluster), + redisCluster, + Clock.systemUTC()); - // now making a check-only call, but changing the bucket size - result = (long) sandbox.execute( - List.of(key), - scriptArgs(2000, 1, 1, false), - redisCommandsHandler - ); - assertEquals(0L, result); - assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize); + final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); + rateLimiter.validate("test", 200); + // after using 200 tokens, we expect 200 seconds to refill, so the TTL should be under 200000 + final long ttl = redisCluster.withCluster(c -> c.sync().ttl("test")); + assertTrue(ttl <= 200000); } @Test @@ -140,7 +137,7 @@ public class RateLimitersLuaScriptTest { redisCommandsHandler ); assertEquals(0L, result); - assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining); + assertEquals(800L, decodeBucket(key).orElseThrow().tokensRemaining); // 50 tokens replenished, acquiring 100 more, should end up with 750 available clock.setTimeMillis(50); @@ -150,7 +147,7 @@ public class RateLimitersLuaScriptTest { redisCommandsHandler ); assertEquals(0L, result); - assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); + assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining); // now checking without an update, should not affect the count result = (long) sandbox.execute( @@ -159,16 +156,20 @@ public class RateLimitersLuaScriptTest { redisCommandsHandler ); assertEquals(0L, result); - assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); + assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining); } - private String serializeToOldBucketValueFormat(final TokenBucket bucket) { + private String serializeToOldBucketValueFormat( + final long bucketSize, + final long leakRatePerMillis, + final long spaceRemaining, + final long lastUpdateTimeMillis) { try { return SystemMapper.jsonMapper().writeValueAsString(Map.of( - "bucketSize", bucket.bucketSize, - "leakRatePerMillis", bucket.leakRatePerMillis, - "spaceRemaining", bucket.spaceRemaining, - "lastUpdateTimeMillis", bucket.lastUpdateTimeMillis + "bucketSize", bucketSize, + "leakRatePerMillis", leakRatePerMillis, + "spaceRemaining", spaceRemaining, + "lastUpdateTimeMillis", lastUpdateTimeMillis )); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -176,14 +177,11 @@ public class RateLimitersLuaScriptTest { } private Optional decodeBucket(final String key) { - try { - final String json = redisCommandsHandler.get(key); - return json == null - ? Optional.empty() - : Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class)); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + final Object[] fields = redisCommandsHandler.hmget(key, List.of("s", "t")); + return fields[0] == null + ? Optional.empty() + : Optional.of(new TokenBucket( + Double.valueOf(fields[0].toString()).longValue(), Double.valueOf(fields[1].toString()).longValue())); } private List scriptArgs( @@ -200,6 +198,6 @@ public class RateLimitersLuaScriptTest { ); } - private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) { + private record TokenBucket(long tokensRemaining, long lastUpdateTimeMillis) { } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java index f5617c1a1..8343722b1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail; import java.util.List; +import java.util.Locale; import java.util.Map; /** @@ -19,7 +20,7 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler { @Override public Object redisCommand(final String command, final List args) { - return switch (command) { + return switch (command.toUpperCase(Locale.ROOT)) { case "SET" -> { assertTrue(args.size() > 2); yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2)); @@ -33,13 +34,18 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler { yield del(args.stream().map(Object::toString).toList()); } case "HSET" -> { - assertTrue(args.size() >= 3); - yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3)); + assertTrue(args.size() > 1); + assertTrue(args.size() % 2 == 1); + yield hset(args.get(0).toString(), tail(args, 1)); } case "HGET" -> { assertEquals(2, args.size()); yield hget(args.get(0).toString(), args.get(1).toString()); } + case "HMGET" -> { + assertTrue(args.size() > 1); + yield hmget(args.get(0).toString(), tail(args, 1)); + } case "PEXPIRE" -> { assertEquals(2, args.size()); yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2)); @@ -85,7 +91,7 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler { return 0; } - public Object hset(final String key, final String field, final String value, final List other) { + public Object hset(final String key, final List fieldsAndValues) { return "OK"; } @@ -93,6 +99,10 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler { return null; } + public Object[] hmget(final String key, final List fields) { + return new Object[fields.size()]; + } + public Object set(final String key, final String value, final List tail) { return "OK"; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java index 75d48c07e..e99615711 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.util.redis; import java.time.Clock; +import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -46,13 +47,18 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { @SuppressWarnings("unchecked") @Override - public Object hset(final String key, final String field, final String value, final List other) { + public Object hset(final String key, final List fieldsAndValues) { Map map = getIfNotExpired(key, Map.class); if (map == null) { map = new ConcurrentHashMap<>(); cache.put(key, new Entry(map, Long.MAX_VALUE)); } - map.put(field, value); + final Iterator iter = fieldsAndValues.iterator(); + while (iter.hasNext()) { + final Object k = iter.next(); + final Object v = iter.next(); + map.put(k, v); + } return "OK"; } @@ -62,6 +68,15 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { return map == null ? null : map.get(field); } + @Override + public Object[] hmget(final String key, final List fields) { + final Object[] res = new Object[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + res[i] = hget(key, fields.get(i).toString()); + } + return res; + } + @SuppressWarnings("unchecked") @Override public Object push(final boolean left, final String key, final List values) {