diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua index d18c4e52c..3851089eb 100644 --- a/service/src/main/resources/lua/validate_rate_limit.lua +++ b/service/src/main/resources/lua/validate_rate_limit.lua @@ -21,30 +21,13 @@ local changesMade = false local tokensRemaining local lastUpdateTimeMillis --- while we're migrating from json to redis list key types, there are three possible options for the --- type of the `bucketId` key: "string" (legacy, json value), "list" (new format), "none" (key not set). --- --- On a separate note -- the reason we're not using a different key is because Redis Lua requires to list all keys --- as a script input and we don't want to expose this migration to the script users. --- --- Finally, it's okay to read the "ok" key of the return here because "TYPE" command always succeeds. -local keyType = redis.call("TYPE", bucketId)["ok"] -if keyType == "none" then - -- if the key is not set, building the object from the configuration - tokensRemaining = bucketSize - lastUpdateTimeMillis = currentTimeMillis -elseif keyType == "string" then - -- if the key is "string", we parse the value from json - local fromJson = cjson.decode(redis.call("GET", bucketId)) - tokensRemaining = fromJson.spaceRemaining - lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis - redis.call("DEL", bucketId) - changesMade = true -elseif keyType == "hash" then - -- finally, reading values from the new storage format - local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD)) +local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD)) +if tokensRemainingStr and lastUpdateTimeMillisStr then tokensRemaining = tonumber(tokensRemainingStr) lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr) +else + tokensRemaining = bucketSize + lastUpdateTimeMillis = currentTimeMillis end local elapsedTime = currentTimeMillis - lastUpdateTimeMillis diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java index d2c3098bb..6d3908470 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -69,50 +69,6 @@ public class RateLimitersLuaScriptTest { assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 25)); } - @Test - public void testFormatMigration() throws Exception { - final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; - final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); - final RateLimiters limiters = new RateLimiters( - Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1))), - dynamicConfig, - RateLimiters.defaultScript(redisCluster), - redisCluster, - Clock.systemUTC()); - - final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); - - // embedding an existing value in the old format - redisCluster.useCluster(c -> c.sync().set( - StaticRateLimiter.bucketName(descriptor.id(), "test"), - serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000) - )); - assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40)); - - // embedding an existing value in the old format - redisCluster.useCluster(c -> c.sync().set( - StaticRateLimiter.bucketName(descriptor.id(), "test1"), - serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000) - )); - rateLimiter.validate("test1", 20); - assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20)); - - // embedding an existing value in the new format - redisCluster.useCluster(c -> c.sync().hset( - StaticRateLimiter.bucketName(descriptor.id(), "test2"), - Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000)) - )); - assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test2", 40)); - - // embedding an existing value in the new format - redisCluster.useCluster(c -> c.sync().hset( - StaticRateLimiter.bucketName(descriptor.id(), "test3"), - Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000)) - )); - rateLimiter.validate("test3", 20); - assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test3", 20)); - } - @Test public void testTtl() throws Exception { final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;