migrate token bucket redis record format from json to hash: phase 1
This commit is contained in:
parent
ebf8aa7b15
commit
483e444174
|
@ -8,6 +8,7 @@ import static java.util.Objects.requireNonNull;
|
|||
import static java.util.concurrent.CompletableFuture.completedFuture;
|
||||
import static java.util.concurrent.CompletableFuture.failedFuture;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Clock;
|
||||
|
@ -88,12 +89,12 @@ public class StaticRateLimiter implements RateLimiter {
|
|||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key)));
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key)))
|
||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
|
@ -103,7 +104,7 @@ public class StaticRateLimiter implements RateLimiter {
|
|||
}
|
||||
|
||||
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(key));
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
|
@ -115,7 +116,7 @@ public class StaticRateLimiter implements RateLimiter {
|
|||
}
|
||||
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(key));
|
||||
final List<String> keys = List.of(bucketName(name, key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
|
@ -126,7 +127,8 @@ public class StaticRateLimiter implements RateLimiter {
|
|||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
private String bucketName(final String key) {
|
||||
@VisibleForTesting
|
||||
protected static String bucketName(final String name, final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,48 +14,72 @@ local currentTimeMillis = tonumber(ARGV[3])
|
|||
local requestedAmount = tonumber(ARGV[4])
|
||||
local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true"
|
||||
|
||||
local tokenBucketJson = redis.call("GET", bucketId)
|
||||
local tokenBucket
|
||||
local SIZE_FIELD = "s"
|
||||
local TIME_FIELD = "t"
|
||||
|
||||
local changesMade = false
|
||||
local tokensRemaining
|
||||
local lastUpdateTimeMillis
|
||||
|
||||
if tokenBucketJson then
|
||||
tokenBucket = cjson.decode(tokenBucketJson)
|
||||
else
|
||||
tokenBucket = {
|
||||
["bucketSize"] = bucketSize,
|
||||
["leakRatePerMillis"] = refillRatePerMillis,
|
||||
["spaceRemaining"] = bucketSize,
|
||||
["lastUpdateTimeMillis"] = currentTimeMillis
|
||||
}
|
||||
end
|
||||
|
||||
-- this can happen if rate limiter configuration has changed while the key is still in Redis
|
||||
if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then
|
||||
tokenBucket["bucketSize"] = bucketSize
|
||||
tokenBucket["leakRatePerMillis"] = refillRatePerMillis
|
||||
-- while we're migrating from json to redis list key types, there are three possible options for the
|
||||
-- type of the `bucketId` key: "string" (legacy, json value), "list" (new format), "none" (key not set).
|
||||
--
|
||||
-- In the phase 1 of migration, we prepare the script to deal with the phase 2 :) I.e. when phase 2 will be rolling out,
|
||||
-- it will start writing data in the new format, and the still running instances of the previous version
|
||||
-- need to be able to know how to read the new format before we start writing it.
|
||||
--
|
||||
-- On a separate note -- the reason we're not using a different key is because Redis Lua requires to list all keys
|
||||
-- as a script input and we don't want to expose this migration to the script users.
|
||||
--
|
||||
-- Finally, it's okay to read the "ok" key of the return here because "TYPE" command always succeeds.
|
||||
local keyType = redis.call("TYPE", bucketId)["ok"]
|
||||
if keyType == "none" then
|
||||
-- if the key is not set, building the object from the configuration
|
||||
tokensRemaining = bucketSize
|
||||
lastUpdateTimeMillis = currentTimeMillis
|
||||
elseif keyType == "string" then
|
||||
-- if the key is "string", we parse the value from json
|
||||
local fromJson = cjson.decode(redis.call("GET", bucketId))
|
||||
if bucketSize ~= fromJson.bucketSize or refillRatePerMillis ~= fromJson.leakRatePerMillis then
|
||||
changesMade = true
|
||||
end
|
||||
tokensRemaining = fromJson.spaceRemaining
|
||||
lastUpdateTimeMillis = fromJson.lastUpdateTimeMillis
|
||||
elseif keyType == "hash" then
|
||||
-- finally, reading values from the new storage format
|
||||
local tokensRemainingStr, lastUpdateTimeMillisStr = unpack(redis.call("HMGET", bucketId, SIZE_FIELD, TIME_FIELD))
|
||||
tokensRemaining = tonumber(tokensRemainingStr)
|
||||
lastUpdateTimeMillis = tonumber(lastUpdateTimeMillisStr)
|
||||
redis.call("DEL", bucketId)
|
||||
changesMade = true
|
||||
end
|
||||
|
||||
local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"]
|
||||
local elapsedTime = currentTimeMillis - lastUpdateTimeMillis
|
||||
local availableAmount = math.min(
|
||||
tokenBucket["bucketSize"],
|
||||
math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"]))
|
||||
bucketSize,
|
||||
math.floor(tokensRemaining + (elapsedTime * refillRatePerMillis))
|
||||
)
|
||||
|
||||
if availableAmount >= requestedAmount then
|
||||
if useTokens then
|
||||
tokenBucket["spaceRemaining"] = availableAmount - requestedAmount
|
||||
tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis
|
||||
tokensRemaining = availableAmount - requestedAmount
|
||||
lastUpdateTimeMillis = currentTimeMillis
|
||||
changesMade = true
|
||||
end
|
||||
if changesMade then
|
||||
local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"]
|
||||
local tokensUsed = bucketSize - tokensRemaining
|
||||
-- Storing a 'full' bucket is equivalent of not storing any state at all
|
||||
-- (in which case a bucket will be just initialized from the input configs as a 'full' one).
|
||||
-- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish)
|
||||
-- or we just delete the key if the bucket is full.
|
||||
if tokensUsed > 0 then
|
||||
local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"])
|
||||
local ttlMillis = math.ceil(tokensUsed / refillRatePerMillis)
|
||||
local tokenBucket = {
|
||||
["bucketSize"] = bucketSize,
|
||||
["leakRatePerMillis"] = refillRatePerMillis,
|
||||
["spaceRemaining"] = tokensRemaining,
|
||||
["lastUpdateTimeMillis"] = lastUpdateTimeMillis
|
||||
}
|
||||
redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis)
|
||||
else
|
||||
redis.call("DEL", bucketId)
|
||||
|
|
|
@ -19,6 +19,7 @@ import java.util.Optional;
|
|||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
|
@ -60,7 +61,51 @@ public class RateLimitersLuaScriptTest {
|
|||
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
|
||||
rateLimiter.validate("test", 25);
|
||||
rateLimiter.validate("test", 25);
|
||||
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 25));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFormatMigration() throws Exception {
|
||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
||||
final RateLimiters limiters = new RateLimiters(
|
||||
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
|
||||
dynamicConfig,
|
||||
RateLimiters.defaultScript(redisCluster),
|
||||
redisCluster,
|
||||
Clock.systemUTC());
|
||||
|
||||
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
|
||||
|
||||
// embedding an existing value in the old format
|
||||
redisCluster.useCluster(c -> c.sync().set(
|
||||
StaticRateLimiter.bucketName(descriptor.id(), "test"),
|
||||
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000))
|
||||
));
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test", 40));
|
||||
|
||||
// embedding an existing value in the old format
|
||||
redisCluster.useCluster(c -> c.sync().set(
|
||||
StaticRateLimiter.bucketName(descriptor.id(), "test1"),
|
||||
serializeToOldBucketValueFormat(new TokenBucket(60, 60, 30, System.currentTimeMillis() + 10000))
|
||||
));
|
||||
rateLimiter.validate("test1", 20);
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test1", 20));
|
||||
|
||||
// embedding an existing value in the new format
|
||||
redisCluster.useCluster(c -> c.sync().hset(
|
||||
StaticRateLimiter.bucketName(descriptor.id(), "test2"),
|
||||
Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000))
|
||||
));
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test2", 40));
|
||||
|
||||
// embedding an existing value in the new format
|
||||
redisCluster.useCluster(c -> c.sync().hset(
|
||||
StaticRateLimiter.bucketName(descriptor.id(), "test3"),
|
||||
Map.of("s", "30", "t", String.valueOf(System.currentTimeMillis() + 10000))
|
||||
));
|
||||
rateLimiter.validate("test3", 20);
|
||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate("test3", 20));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -117,6 +162,19 @@ public class RateLimitersLuaScriptTest {
|
|||
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
}
|
||||
|
||||
private String serializeToOldBucketValueFormat(final TokenBucket bucket) {
|
||||
try {
|
||||
return SystemMapper.jsonMapper().writeValueAsString(Map.of(
|
||||
"bucketSize", bucket.bucketSize,
|
||||
"leakRatePerMillis", bucket.leakRatePerMillis,
|
||||
"spaceRemaining", bucket.spaceRemaining,
|
||||
"lastUpdateTimeMillis", bucket.lastUpdateTimeMillis
|
||||
));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<TokenBucket> decodeBucket(final String key) {
|
||||
try {
|
||||
final String json = redisCommandsHandler.get(key);
|
||||
|
|
|
@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* This class is to be extended with implementations of Redis commands as needed.
|
||||
|
@ -28,23 +29,79 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
|
|||
yield get(args.get(0).toString());
|
||||
}
|
||||
case "DEL" -> {
|
||||
assertTrue(args.size() > 1);
|
||||
yield del(args.get(0).toString());
|
||||
assertTrue(args.size() >= 1);
|
||||
yield del(args.stream().map(Object::toString).toList());
|
||||
}
|
||||
case "HSET" -> {
|
||||
assertTrue(args.size() >= 3);
|
||||
yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3));
|
||||
}
|
||||
case "HGET" -> {
|
||||
assertEquals(2, args.size());
|
||||
yield hget(args.get(0).toString(), args.get(1).toString());
|
||||
}
|
||||
case "PEXPIRE" -> {
|
||||
assertEquals(2, args.size());
|
||||
yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2));
|
||||
}
|
||||
case "TYPE" -> {
|
||||
assertEquals(1, args.size());
|
||||
yield type(args.get(0).toString());
|
||||
}
|
||||
case "RPUSH" -> {
|
||||
assertTrue(args.size() > 1);
|
||||
yield push(false, args.get(0).toString(), tail(args, 1));
|
||||
}
|
||||
case "LPUSH" -> {
|
||||
assertTrue(args.size() > 1);
|
||||
yield push(true, args.get(0).toString(), tail(args, 1));
|
||||
}
|
||||
case "RPOP" -> {
|
||||
assertEquals(2, args.size());
|
||||
yield pop(false, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
|
||||
}
|
||||
case "LPOP" -> {
|
||||
assertEquals(2, args.size());
|
||||
yield pop(true, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
|
||||
}
|
||||
|
||||
default -> other(command, args);
|
||||
};
|
||||
}
|
||||
|
||||
public Object[] pop(final boolean left, final String key, final int count) {
|
||||
return new Object[count];
|
||||
}
|
||||
|
||||
public Object push(final boolean left, final String key, final List<Object> values) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public Object type(final String key) {
|
||||
return Map.of("ok", "none");
|
||||
}
|
||||
|
||||
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public Object hset(final String key, final String field, final String value, final List<Object> other) {
|
||||
return "OK";
|
||||
}
|
||||
|
||||
public Object hget(final String key, final String field) {
|
||||
return null;
|
||||
}
|
||||
|
||||
public Object set(final String key, final String value, final List<Object> tail) {
|
||||
return "OK";
|
||||
}
|
||||
|
||||
public String get(final String key) {
|
||||
return null;
|
||||
|
||||
}
|
||||
|
||||
public int del(final String key) {
|
||||
public int del(final List<String> keys) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,12 @@ public class RedisLuaScriptSandbox {
|
|||
function redis_call(...)
|
||||
-- variable name needs to match the one used in the `L.setGlobal()` call
|
||||
-- method name needs to match method name of the Java class
|
||||
return proxy:redisCall(arg)
|
||||
local result = proxy:redisCall(arg)
|
||||
if type(result) == "userdata" then
|
||||
return java.luaify(result)
|
||||
else
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
function json_encode(obj)
|
||||
|
|
|
@ -6,13 +6,15 @@
|
|||
package org.whispersystems.textsecuregcm.util.redis;
|
||||
|
||||
import java.time.Clock;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
|
||||
|
||||
public record Entry(String value, long expirationEpochMillis) {
|
||||
public record Entry(Object value, long expirationEpochMillis) {
|
||||
}
|
||||
|
||||
private final Map<String, Entry> cache = new ConcurrentHashMap<>();
|
||||
|
@ -32,20 +34,106 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
|
|||
|
||||
@Override
|
||||
public String get(final String key) {
|
||||
return getIfNotExpired(key, String.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int del(final List<String> key) {
|
||||
return key.stream()
|
||||
.mapToInt(k -> cache.remove(k) != null ? 1 : 0)
|
||||
.sum();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Object hset(final String key, final String field, final String value, final List<Object> other) {
|
||||
Map<Object, Object> map = getIfNotExpired(key, Map.class);
|
||||
if (map == null) {
|
||||
map = new ConcurrentHashMap<>();
|
||||
cache.put(key, new Entry(map, Long.MAX_VALUE));
|
||||
}
|
||||
map.put(field, value);
|
||||
return "OK";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object hget(final String key, final String field) {
|
||||
final Map<?, ?> map = getIfNotExpired(key, Map.class);
|
||||
return map == null ? null : map.get(field);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Object push(final boolean left, final String key, final List<Object> values) {
|
||||
LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
|
||||
if (list == null) {
|
||||
list = new LinkedList<>();
|
||||
cache.put(key, new Entry(list, Long.MAX_VALUE));
|
||||
}
|
||||
for (Object v: values) {
|
||||
if (left) {
|
||||
list.addFirst(v.toString());
|
||||
} else {
|
||||
list.addLast(v.toString());
|
||||
}
|
||||
}
|
||||
return list.size();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Object[] pop(final boolean left, final String key, final int count) {
|
||||
final Object[] result = new String[count];
|
||||
final LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
|
||||
if (list == null) {
|
||||
return result;
|
||||
}
|
||||
for (int i = 0; i < Math.min(count, list.size()); i++) {
|
||||
result[i] = left ? list.removeFirst() : list.removeLast();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
|
||||
final Entry e = cache.get(key);
|
||||
if (e == null) {
|
||||
return 0;
|
||||
}
|
||||
final Entry updated = new Entry(e.value(), clock.millis() + ttlMillis);
|
||||
cache.put(key, updated);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object type(final String key) {
|
||||
final Object o = getIfNotExpired(key, Object.class);
|
||||
final String type;
|
||||
if (o == null) {
|
||||
type = "none";
|
||||
} else if (o.getClass() == String.class) {
|
||||
type = "string";
|
||||
} else if (Map.class.isAssignableFrom(o.getClass())) {
|
||||
type = "hash";
|
||||
} else if (List.class.isAssignableFrom(o.getClass())) {
|
||||
type = "list";
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported value type: " + o.getClass());
|
||||
}
|
||||
return Map.of("ok", type);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected <T> T getIfNotExpired(final String key, final Class<T> expectedType) {
|
||||
final Entry entry = cache.get(key);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
if (entry.expirationEpochMillis() < clock.millis()) {
|
||||
del(key);
|
||||
del(List.of(key));
|
||||
return null;
|
||||
}
|
||||
return entry.value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int del(final String key) {
|
||||
return cache.remove(key) != null ? 1 : 0;
|
||||
return expectedType.cast(entry.value());
|
||||
}
|
||||
|
||||
protected long resolveExpirationEpochMillis(final List<Object> args) {
|
||||
|
|
Loading…
Reference in New Issue