migrate token bucket redis record format from json to hash: phase 2

This commit is contained in:
Sergey Skrobotov 2023-03-16 01:41:48 -07:00
parent a04fe133b6
commit c14ef7e6cf
4 changed files with 71 additions and 60 deletions

View File

@ -24,10 +24,6 @@ local lastUpdateTimeMillis
-- while we're migrating from json to redis hash key types, there are three possible options for the -- while we're migrating from json to redis hash key types, there are three possible options for the
-- type of the `bucketId` key: "string" (legacy, json value), "hash" (new format), "none" (key not set). -- type of the `bucketId` key: "string" (legacy, json value), "hash" (new format), "none" (key not set).
-- --
-- In the phase 1 of migration, we prepare the script to deal with the phase 2 :) I.e. when phase 2 will be rolling out,
-- it will start writing data in the new format, and the still running instances of the previous version
-- need to be able to know how to read the new format before we start writing it.
--
-- On a separate note -- the reason we're not using a different key is because Redis Lua requires listing all keys -- On a separate note -- the reason we're not using a different key is because Redis Lua requires listing all keys
-- as a script input and we don't want to expose this migration to the script users. -- as a script input and we don't want to expose this migration to the script users.
-- --
@ -40,18 +36,15 @@ if keyType == "none" then
elseif keyType == "string" then elseif keyType == "string" then
-- if the key is "string", we parse the value from json -- if the key is "string", we parse the value from json
local fromJson = cjson.decode(redis.call("GET", bucketId)) local fromJson = cjson.decode(redis.call("GET", bucketId))
if bucketSize ~= fromJson.bucketSize or refillRatePerMillis ~= fromJson.leakRatePerMillis then
changesMade = true
end
tokensRemaining = fromJson.spaceRemaining tokensRemaining = fromJson.spaceRemaining
lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis
redis.call("DEL", bucketId)
changesMade = true
elseif keyType == "hash" then elseif keyType == "hash" then
-- finally, reading values from the new storage format -- finally, reading values from the new storage format
local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD)) local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD))
tokensRemaining = tonumber(tokensRemainingStr) tokensRemaining = tonumber(tokensRemainingStr)
lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr) lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr)
redis.call("DEL", bucketId)
changesMade = true
end end
local elapsedTime = currentTimeMillis - lastUpdateTimeMillis local elapsedTime = currentTimeMillis - lastUpdateTimeMillis
@ -68,19 +61,14 @@ if availableAmount >= requestedAmount then
end end
if changesMade then if changesMade then
local tokensUsed = bucketSize - tokensRemaining local tokensUsed = bucketSize - tokensRemaining
-- Storing a 'full' bucket is equivalent to not storing any state at all -- Storing a 'full' bucket (i.e. tokensUsed == 0) is equivalent to not storing any state at all
-- (in which case a bucket will be just initialized from the input configs as a 'full' one). -- (in which case a bucket will be just initialized from the input configs as a 'full' one).
-- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish) -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish)
-- or we just delete the key if the bucket is full. -- or we just delete the key if the bucket is full.
if tokensUsed > 0 then if tokensUsed > 0 then
local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis) local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis)
local tokenBucket = { redis.call("HSET", bucketId, SIZE_FIELD, tokensRemaining, TIME_FIELD, lastUpdateTimeMillis)
["bucketSize"] = bucketSize, redis.call("PEXPIRE", bucketId, ttlMillis)
["leakRatePerMillis"] = refillRatePerMillis,
["spaceRemaining"] = tokensRemaining,
["lastUpdateTimeMillis"] = lastUpdateTimeMillis
}
redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis)
else else
redis.call("DEL", bucketId) redis.call("DEL", bucketId)
end end

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -80,14 +81,14 @@ public class RateLimitersLuaScriptTest {
// embedding an existing value in the old format // embedding an existing value in the old format
redisCluster.useCluster(c -> c.sync().set( redisCluster.useCluster(c -> c.sync().set(
StaticRateLimiter.bucketName(descriptor.id(), "test"), StaticRateLimiter.bucketName(descriptor.id(), "test"),
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000)
)); ));
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40)); assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40));
// embedding an existing value in the old format // embedding an existing value in the old format
redisCluster.useCluster(c -> c.sync().set( redisCluster.useCluster(c -> c.sync().set(
StaticRateLimiter.bucketName(descriptor.id(), "test1"), StaticRateLimiter.bucketName(descriptor.id(), "test1"),
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000)) serializeToOldBucketValueFormat(60, 60, 30, System.currentTimeMillis() + 10000)
)); ));
rateLimiter.validate("test1", 20); rateLimiter.validate("test1", 20);
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20)); assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20));
@ -109,25 +110,21 @@ public class RateLimitersLuaScriptTest {
} }
@Test @Test
public void testLuaBucketConfigurationUpdates() throws Exception { public void testTtl() throws Exception {
final String key = "key1"; final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
clock.setTimeMillis(0); final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
long result = (long) sandbox.execute( final RateLimiters limiters = new RateLimiters(
List.of(key), Map.of(descriptor.id(), new RateLimiterConfig(1000, 60)),
scriptArgs(1000, 1, 1, true), dynamicConfig,
redisCommandsHandler RateLimiters.defaultScript(redisCluster),
); redisCluster,
assertEquals(0L, result); Clock.systemUTC());
assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
// now making a check-only call, but changing the bucket size final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
result = (long) sandbox.execute( rateLimiter.validate("test", 200);
List.of(key), // after using 200 tokens, we expect 200 seconds to refill, so the TTL should be under 200000
scriptArgs(2000, 1, 1, false), final long ttl = redisCluster.withCluster(c -> c.sync().ttl("test"));
redisCommandsHandler assertTrue(ttl <= 200000);
);
assertEquals(0L, result);
assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
} }
@Test @Test
@ -140,7 +137,7 @@ public class RateLimitersLuaScriptTest {
redisCommandsHandler redisCommandsHandler
); );
assertEquals(0L, result); assertEquals(0L, result);
assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining); assertEquals(800L, decodeBucket(key).orElseThrow().tokensRemaining);
// 50 tokens replenished, acquiring 100 more, should end up with 750 available // 50 tokens replenished, acquiring 100 more, should end up with 750 available
clock.setTimeMillis(50); clock.setTimeMillis(50);
@ -150,7 +147,7 @@ public class RateLimitersLuaScriptTest {
redisCommandsHandler redisCommandsHandler
); );
assertEquals(0L, result); assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining);
// now checking without an update, should not affect the count // now checking without an update, should not affect the count
result = (long) sandbox.execute( result = (long) sandbox.execute(
@ -159,16 +156,20 @@ public class RateLimitersLuaScriptTest {
redisCommandsHandler redisCommandsHandler
); );
assertEquals(0L, result); assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining);
} }
private String serializeToOldBucketValueFormat(final TokenBucket bucket) { private String serializeToOldBucketValueFormat(
final long bucketSize,
final long leakRatePerMillis,
final long spaceRemaining,
final long lastUpdateTimeMillis) {
try { try {
return SystemMapper.jsonMapper().writeValueAsString(Map.of( return SystemMapper.jsonMapper().writeValueAsString(Map.of(
"bucketSize", bucket.bucketSize, "bucketSize", bucketSize,
"leakRatePerMillis", bucket.leakRatePerMillis, "leakRatePerMillis", leakRatePerMillis,
"spaceRemaining", bucket.spaceRemaining, "spaceRemaining", spaceRemaining,
"lastUpdateTimeMillis", bucket.lastUpdateTimeMillis "lastUpdateTimeMillis", lastUpdateTimeMillis
)); ));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -176,14 +177,11 @@ public class RateLimitersLuaScriptTest {
} }
private Optional<TokenBucket> decodeBucket(final String key) { private Optional<TokenBucket> decodeBucket(final String key) {
try { final Object[] fields = redisCommandsHandler.hmget(key, List.of("s", "t"));
final String json = redisCommandsHandler.get(key); return fields[0] == null
return json == null ? Optional.empty()
? Optional.empty() : Optional.of(new TokenBucket(
: Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class)); Double.valueOf(fields[0].toString()).longValue(), Double.valueOf(fields[1].toString()).longValue()));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
} }
private List<String> scriptArgs( private List<String> scriptArgs(
@ -200,6 +198,6 @@ public class RateLimitersLuaScriptTest {
); );
} }
private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) { private record TokenBucket(long tokensRemaining, long lastUpdateTimeMillis) {
} }
} }

View File

@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail; import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
/** /**
@ -19,7 +20,7 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
@Override @Override
public Object redisCommand(final String command, final List<Object> args) { public Object redisCommand(final String command, final List<Object> args) {
return switch (command) { return switch (command.toUpperCase(Locale.ROOT)) {
case "SET" -> { case "SET" -> {
assertTrue(args.size() > 2); assertTrue(args.size() > 2);
yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2)); yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2));
@ -33,13 +34,18 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
yield del(args.stream().map(Object::toString).toList()); yield del(args.stream().map(Object::toString).toList());
} }
case "HSET" -> { case "HSET" -> {
assertTrue(args.size() >= 3); assertTrue(args.size() > 1);
yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3)); assertTrue(args.size() % 2 == 1);
yield hset(args.get(0).toString(), tail(args, 1));
} }
case "HGET" -> { case "HGET" -> {
assertEquals(2, args.size()); assertEquals(2, args.size());
yield hget(args.get(0).toString(), args.get(1).toString()); yield hget(args.get(0).toString(), args.get(1).toString());
} }
case "HMGET" -> {
assertTrue(args.size() > 1);
yield hmget(args.get(0).toString(), tail(args, 1));
}
case "PEXPIRE" -> { case "PEXPIRE" -> {
assertEquals(2, args.size()); assertEquals(2, args.size());
yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2)); yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2));
@ -85,7 +91,7 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
return 0; return 0;
} }
public Object hset(final String key, final String field, final String value, final List<Object> other) { public Object hset(final String key, final List<Object> fieldsAndValues) {
return "OK"; return "OK";
} }
@ -93,6 +99,10 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
return null; return null;
} }
public Object[] hmget(final String key, final List<Object> fields) {
return new Object[fields.size()];
}
public Object set(final String key, final String value, final List<Object> tail) { public Object set(final String key, final String value, final List<Object> tail) {
return "OK"; return "OK";
} }

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.util.redis; package org.whispersystems.textsecuregcm.util.redis;
import java.time.Clock; import java.time.Clock;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -46,13 +47,18 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public Object hset(final String key, final String field, final String value, final List<Object> other) { public Object hset(final String key, final List<Object> fieldsAndValues) {
Map<Object, Object> map = getIfNotExpired(key, Map.class); Map<Object, Object> map = getIfNotExpired(key, Map.class);
if (map == null) { if (map == null) {
map = new ConcurrentHashMap<>(); map = new ConcurrentHashMap<>();
cache.put(key, new Entry(map, Long.MAX_VALUE)); cache.put(key, new Entry(map, Long.MAX_VALUE));
} }
map.put(field, value); final Iterator<Object> iter = fieldsAndValues.iterator();
while (iter.hasNext()) {
final Object k = iter.next();
final Object v = iter.next();
map.put(k, v);
}
return "OK"; return "OK";
} }
@ -62,6 +68,15 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
return map == null ? null : map.get(field); return map == null ? null : map.get(field);
} }
@Override
public Object[] hmget(final String key, final List<Object> fields) {
final Object[] res = new Object[fields.size()];
for (int i = 0; i < fields.size(); i++) {
res[i] = hget(key, fields.get(i).toString());
}
return res;
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public Object push(final boolean left, final String key, final List<Object> values) { public Object push(final boolean left, final String key, final List<Object> values) {