migrate token bucket redis record format from json to hash: phase 1

This commit is contained in:
Sergey Skrobotov 2023-03-15 14:32:14 -07:00
parent ebf8aa7b15
commit 483e444174
6 changed files with 277 additions and 43 deletions

View File

@ -8,6 +8,7 @@ import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
@ -88,12 +89,12 @@ public class StaticRateLimiter implements RateLimiter {
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key)));
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key)))
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
.thenRun(Util.NOOP);
}
@ -103,7 +104,7 @@ public class StaticRateLimiter implements RateLimiter {
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(key));
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
@ -115,7 +116,7 @@ public class StaticRateLimiter implements RateLimiter {
}
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(key));
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
@ -126,7 +127,8 @@ public class StaticRateLimiter implements RateLimiter {
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
private String bucketName(final String key) {
@VisibleForTesting
protected static String bucketName(final String name, final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -14,48 +14,72 @@ local currentTimeMillis = tonumber(ARGV[3])
local requestedAmount = tonumber(ARGV[4])
local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true"
local tokenBucketJson = redis.call("GET", bucketId)
local tokenBucket
local SIZE_FIELD = "s"
local TIME_FIELD = "t"
local changesMade = false
local tokensRemaining
local lastUpdateTimeMillis
if tokenBucketJson then
tokenBucket = cjson.decode(tokenBucketJson)
else
tokenBucket = {
["bucketSize"] = bucketSize,
["leakRatePerMillis"] = refillRatePerMillis,
["spaceRemaining"] = bucketSize,
["lastUpdateTimeMillis"] = currentTimeMillis
}
end
-- this can happen if rate limiter configuration has changed while the key is still in Redis
if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then
tokenBucket["bucketSize"] = bucketSize
tokenBucket["leakRatePerMillis"] = refillRatePerMillis
-- while we're migrating from json to redis list key types, there are three possible options for the
-- type of the `bucketId` key: "string" (legacy, json value), "list" (new format), "none" (key not set).
--
-- In the phase 1 of migration, we prepare the script to deal with the phase 2 :) I.e. when phase 2 will be rolling out,
-- it will start writing data in the new format, and the still running instances of the previous version
-- need to be able to know how to read the new format before we start writing it.
--
-- On a separate note -- the reason we're not using a different key is because Redis Lua requires to list all keys
-- as a script input and we don't want to expose this migration to the script users.
--
-- Finally, it's okay to read the "ok" key of the return here because "TYPE" command always succeeds.
local keyType = redis.call("TYPE", bucketId)["ok"]
if keyType == "none" then
-- if the key is not set, building the object from the configuration
tokensRemaining = bucketSize
lastUpdateTimeMillis = currentTimeMillis
elseif keyType == "string" then
-- if the key is "string", we parse the value from json
local fromJson = cjson.decode(redis.call("GET", bucketId))
if bucketSize ~= fromJson.bucketSize or refillRatePerMillis ~= fromJson.leakRatePerMillis then
changesMade = true
end
tokensRemaining = fromJson.spaceRemaining
lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis
elseif keyType == "hash" then
-- finally, reading values from the new storage format
local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD))
tokensRemaining = tonumber(tokensRemainingStr)
lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr)
redis.call("DEL", bucketId)
changesMade = true
end
local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"]
local elapsedTime = currentTimeMillis - lastUpdateTimeMillis
local availableAmount = math.min(
tokenBucket["bucketSize"],
math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"]))
bucketSize,
math.floor(tokensRemaining + (elapsedTime * refillRatePerMillis))
)
if availableAmount >= requestedAmount then
if useTokens then
tokenBucket["spaceRemaining"] = availableAmount - requestedAmount
tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis
tokensRemaining = availableAmount - requestedAmount
lastUpdateTimeMillis = currentTimeMillis
changesMade = true
end
if changesMade then
local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"]
local tokensUsed = bucketSize - tokensRemaining
-- Storing a 'full' bucket is equivalent of not storing any state at all
-- (in which case a bucket will be just initialized from the input configs as a 'full' one).
-- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish)
-- or we just delete the key if the bucket is full.
if tokensUsed > 0 then
local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"])
local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis)
local tokenBucket = {
["bucketSize"] = bucketSize,
["leakRatePerMillis"] = refillRatePerMillis,
["spaceRemaining"] = tokensRemaining,
["lastUpdateTimeMillis"] = lastUpdateTimeMillis
}
redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis)
else
redis.call("DEL", bucketId)

View File

@ -19,6 +19,7 @@ import java.util.Optional;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@ -60,7 +61,51 @@ public class RateLimitersLuaScriptTest {
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
rateLimiter.validate("test", 25);
rateLimiter.validate("test", 25);
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 25));
}
@Test
public void testFormatMigration() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
dynamicConfig,
RateLimiters.defaultScript(redisCluster),
redisCluster,
Clock.systemUTC());
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
// embedding an existing value in the old format
redisCluster.useCluster(c -> c.sync().set(
StaticRateLimiter.bucketName(descriptor.id(), "test"),
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000))
));
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40));
// embedding an existing value in the old format
redisCluster.useCluster(c -> c.sync().set(
StaticRateLimiter.bucketName(descriptor.id(), "test1"),
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000))
));
rateLimiter.validate("test1", 20);
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20));
// embedding an existing value in the new format
redisCluster.useCluster(c -> c.sync().hset(
StaticRateLimiter.bucketName(descriptor.id(), "test2"),
Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000))
));
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test2", 40));
// embedding an existing value in the new format
redisCluster.useCluster(c -> c.sync().hset(
StaticRateLimiter.bucketName(descriptor.id(), "test3"),
Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000))
));
rateLimiter.validate("test3", 20);
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test3", 20));
}
@Test
@ -117,6 +162,19 @@ public class RateLimitersLuaScriptTest {
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
}
private String serializeToOldBucketValueFormat(final TokenBucket bucket) {
try {
return SystemMapper.jsonMapper().writeValueAsString(Map.of(
"bucketSize", bucket.bucketSize,
"leakRatePerMillis", bucket.leakRatePerMillis,
"spaceRemaining", bucket.spaceRemaining,
"lastUpdateTimeMillis", bucket.lastUpdateTimeMillis
));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
private Optional<TokenBucket> decodeBucket(final String key) {
try {
final String json = redisCommandsHandler.get(key);

View File

@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
import java.util.List;
import java.util.Map;
/**
* This class is to be extended with implementations of Redis commands as needed.
@ -28,23 +29,79 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
yield get(args.get(0).toString());
}
case "DEL" -> {
assertTrue(args.size() > 1);
yield del(args.get(0).toString());
assertTrue(args.size() >= 1);
yield del(args.stream().map(Object::toString).toList());
}
case "HSET" -> {
assertTrue(args.size() >= 3);
yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3));
}
case "HGET" -> {
assertEquals(2, args.size());
yield hget(args.get(0).toString(), args.get(1).toString());
}
case "PEXPIRE" -> {
assertEquals(2, args.size());
yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2));
}
case "TYPE" -> {
assertEquals(1, args.size());
yield type(args.get(0).toString());
}
case "RPUSH" -> {
assertTrue(args.size() > 1);
yield push(false, args.get(0).toString(), tail(args, 1));
}
case "LPUSH" -> {
assertTrue(args.size() > 1);
yield push(true, args.get(0).toString(), tail(args, 1));
}
case "RPOP" -> {
assertEquals(2, args.size());
yield pop(false, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
}
case "LPOP" -> {
assertEquals(2, args.size());
yield pop(true, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
}
default -> other(command, args);
};
}
public Object[] pop(final boolean left, final String key, final int count) {
return new Object[count];
}
public Object push(final boolean left, final String key, final List<Object> values) {
return 0;
}
public Object type(final String key) {
return Map.of("ok", "none");
}
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
return 0;
}
public Object hset(final String key, final String field, final String value, final List<Object> other) {
return "OK";
}
public Object hget(final String key, final String field) {
return null;
}
public Object set(final String key, final String value, final List<Object> tail) {
return "OK";
}
public String get(final String key) {
return null;
}
public int del(final String key) {
public int del(final List<String> keys) {
return 0;
}

View File

@ -28,7 +28,12 @@ public class RedisLuaScriptSandbox {
function redis_call(...)
-- variable name needs to match the one used in the `L.setGlobal()` call
-- method name needs to match method name of the Java class
return proxy:redisCall(arg)
local result = proxy:redisCall(arg)
if type(result) == "userdata" then
return java.luaify(result)
else
return result
end
end
function json_encode(obj)

View File

@ -6,13 +6,15 @@
package org.whispersystems.textsecuregcm.util.redis;
import java.time.Clock;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
public record Entry(String value, long expirationEpochMillis) {
public record Entry(Object value, long expirationEpochMillis) {
}
private final Map<String, Entry> cache = new ConcurrentHashMap<>();
@ -32,20 +34,106 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
@Override
public String get(final String key) {
return getIfNotExpired(key, String.class);
}
@Override
public int del(final List<String> key) {
return key.stream()
.mapToInt(k -> cache.remove(k) != null ? 1 : 0)
.sum();
}
@SuppressWarnings("unchecked")
@Override
public Object hset(final String key, final String field, final String value, final List<Object> other) {
Map<Object, Object> map = getIfNotExpired(key, Map.class);
if (map == null) {
map = new ConcurrentHashMap<>();
cache.put(key, new Entry(map, Long.MAX_VALUE));
}
map.put(field, value);
return "OK";
}
@Override
public Object hget(final String key, final String field) {
final Map<?, ?> map = getIfNotExpired(key, Map.class);
return map == null ? null : map.get(field);
}
@SuppressWarnings("unchecked")
@Override
public Object push(final boolean left, final String key, final List<Object> values) {
LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
if (list == null) {
list = new LinkedList<>();
cache.put(key, new Entry(list, Long.MAX_VALUE));
}
for (Object v: values) {
if (left) {
list.addFirst(v.toString());
} else {
list.addLast(v.toString());
}
}
return list.size();
}
@SuppressWarnings("unchecked")
@Override
public Object[] pop(final boolean left, final String key, final int count) {
final Object[] result = new String[count];
final LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
if (list == null) {
return result;
}
for (int i = 0; i < Math.min(count, list.size()); i++) {
result[i] = left ? list.removeFirst() : list.removeLast();
}
return result;
}
@Override
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
final Entry e = cache.get(key);
if (e == null) {
return 0;
}
final Entry updated = new Entry(e.value(), clock.millis() + ttlMillis);
cache.put(key, updated);
return 1;
}
@Override
public Object type(final String key) {
final Object o = getIfNotExpired(key, Object.class);
final String type;
if (o == null) {
type = "none";
} else if (o.getClass() == String.class) {
type = "string";
} else if (Map.class.isAssignableFrom(o.getClass())) {
type = "hash";
} else if (List.class.isAssignableFrom(o.getClass())) {
type = "list";
} else {
throw new IllegalArgumentException("Unsupported value type: " + o.getClass());
}
return Map.of("ok", type);
}
@Nullable
protected <T> T getIfNotExpired(final String key, final Class<T> expectedType) {
final Entry entry = cache.get(key);
if (entry == null) {
return null;
}
if (entry.expirationEpochMillis() < clock.millis()) {
del(key);
del(List.of(key));
return null;
}
return entry.value();
}
@Override
public int del(final String key) {
return cache.remove(key) != null ? 1 : 0;
return expectedType.cast(entry.value());
}
protected long resolveExpirationEpochMillis(final List<Object> args) {