Moving RateLimiter logic to Redis Lua and adding async API
This commit is contained in:
		
							parent
							
								
									46fef4082c
								
							
						
					
					
						commit
						4c85e7ba66
					
				
							
								
								
									
										1
									
								
								pom.xml
								
								
								
								
							
							
						
						
									
										1
									
								
								pom.xml
								
								
								
								
							|  | @ -64,6 +64,7 @@ | |||
|     <lettuce.version>6.2.1.RELEASE</lettuce.version> | ||||
|     <libphonenumber.version>8.12.54</libphonenumber.version> | ||||
|     <logstash.logback.version>7.2</logstash.logback.version> | ||||
|     <luajava.version>3.3.0</luajava.version> | ||||
|     <micrometer.version>1.10.3</micrometer.version> | ||||
|     <mockito.version>4.11.0</mockito.version> | ||||
|     <netty.version>4.1.82.Final</netty.version> | ||||
|  |  | |||
|  | @ -172,6 +172,26 @@ | |||
|       </exclusions> | ||||
|     </dependency> | ||||
| 
 | ||||
|     <dependency> | ||||
|       <groupId>party.iroiro.luajava</groupId> | ||||
|       <artifactId>luajava</artifactId> | ||||
|       <version>${luajava.version}</version> | ||||
|       <scope>test</scope> | ||||
|     </dependency> | ||||
|     <dependency> | ||||
|       <groupId>party.iroiro.luajava</groupId> | ||||
|       <artifactId>lua51</artifactId> | ||||
|       <version>${luajava.version}</version> | ||||
|       <scope>test</scope> | ||||
|     </dependency> | ||||
|     <dependency> | ||||
|       <groupId>party.iroiro.luajava</groupId> | ||||
|       <artifactId>lua51-platform</artifactId> | ||||
|       <version>${luajava.version}</version> | ||||
|       <classifier>natives-desktop</classifier> | ||||
|       <scope>runtime</scope> | ||||
|     </dependency> | ||||
| 
 | ||||
|     <dependency> | ||||
|       <groupId>org.eclipse.jetty.websocket</groupId> | ||||
|       <artifactId>websocket-api</artifactId> | ||||
|  |  | |||
|  | @ -7,7 +7,11 @@ package org.whispersystems.textsecuregcm.limits; | |||
| 
 | ||||
| import static java.util.Objects.requireNonNull; | ||||
| 
 | ||||
| import io.lettuce.core.ScriptOutputType; | ||||
| import java.io.IOException; | ||||
| import java.io.UncheckedIOException; | ||||
| import java.lang.invoke.MethodHandles; | ||||
| import java.time.Clock; | ||||
| import java.util.Arrays; | ||||
| import java.util.Map; | ||||
| import java.util.Set; | ||||
|  | @ -17,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair; | |||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; | ||||
| import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||
| 
 | ||||
|  | @ -33,12 +38,14 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> { | |||
|       final T[] values, | ||||
|       final Map<String, RateLimiterConfig> configs, | ||||
|       final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|       final ClusterLuaScript validateScript, | ||||
|       final FaultTolerantRedisCluster cacheCluster, | ||||
|       final Clock clock) { | ||||
|     this.configs = configs; | ||||
|     this.rateLimiterByDescriptor = Arrays.stream(values) | ||||
|         .map(descriptor -> Pair.of( | ||||
|             descriptor, | ||||
|             createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster))) | ||||
|             createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock))) | ||||
|         .collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue)); | ||||
|   } | ||||
| 
 | ||||
|  | @ -62,11 +69,22 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> { | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   protected static ClusterLuaScript defaultScript(final FaultTolerantRedisCluster cacheCluster) { | ||||
|     try { | ||||
|       return ClusterLuaScript.fromResource( | ||||
|           cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); | ||||
|     } catch (final IOException e) { | ||||
|       throw new UncheckedIOException("Failed to load rate limit validation script", e); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private static RateLimiter createForDescriptor( | ||||
|       final RateLimiterDescriptor descriptor, | ||||
|       final Map<String, RateLimiterConfig> configs, | ||||
|       final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|       final ClusterLuaScript validateScript, | ||||
|       final FaultTolerantRedisCluster cacheCluster, | ||||
|       final Clock clock) { | ||||
|     if (descriptor.isDynamic()) { | ||||
|       final Supplier<RateLimiterConfig> configResolver = () -> { | ||||
|         final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id()); | ||||
|  | @ -74,9 +92,9 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> { | |||
|             ? config | ||||
|             : configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); | ||||
|       }; | ||||
|       return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster); | ||||
|       return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock); | ||||
|     } | ||||
|     final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); | ||||
|     return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster); | ||||
|     return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock); | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -7,10 +7,13 @@ package org.whispersystems.textsecuregcm.limits; | |||
| 
 | ||||
| import static java.util.Objects.requireNonNull; | ||||
| 
 | ||||
| import java.time.Clock; | ||||
| import java.util.concurrent.CompletionStage; | ||||
| import java.util.concurrent.atomic.AtomicReference; | ||||
| import java.util.function.Supplier; | ||||
| import org.apache.commons.lang3.tuple.Pair; | ||||
| import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; | ||||
| import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| 
 | ||||
| public class DynamicRateLimiter implements RateLimiter { | ||||
|  | @ -19,18 +22,26 @@ public class DynamicRateLimiter implements RateLimiter { | |||
| 
 | ||||
|   private final Supplier<RateLimiterConfig> configResolver; | ||||
| 
 | ||||
|   private final ClusterLuaScript validateScript; | ||||
| 
 | ||||
|   private final FaultTolerantRedisCluster cluster; | ||||
| 
 | ||||
|   private final Clock clock; | ||||
| 
 | ||||
|   private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>(); | ||||
| 
 | ||||
| 
 | ||||
|   public DynamicRateLimiter( | ||||
|       final String name, | ||||
|       final Supplier<RateLimiterConfig> configResolver, | ||||
|       final FaultTolerantRedisCluster cluster) { | ||||
|       final ClusterLuaScript validateScript, | ||||
|       final FaultTolerantRedisCluster cluster, | ||||
|       final Clock clock) { | ||||
|     this.name = requireNonNull(name); | ||||
|     this.configResolver = requireNonNull(configResolver); | ||||
|     this.validateScript = requireNonNull(validateScript); | ||||
|     this.cluster = requireNonNull(cluster); | ||||
|     this.clock = requireNonNull(clock); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|  | @ -38,16 +49,31 @@ public class DynamicRateLimiter implements RateLimiter { | |||
|     current().getRight().validate(key, amount); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public CompletionStage<Void> validateAsync(final String key, final int amount) { | ||||
|     return current().getRight().validateAsync(key, amount); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public boolean hasAvailablePermits(final String key, final int permits) { | ||||
|     return current().getRight().hasAvailablePermits(key, permits); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) { | ||||
|     return current().getRight().hasAvailablePermitsAsync(key, amount); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public void clear(final String key) { | ||||
|     current().getRight().clear(key); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public CompletionStage<Void> clearAsync(final String key) { | ||||
|     return current().getRight().clearAsync(key); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public RateLimiterConfig config() { | ||||
|     return current().getLeft(); | ||||
|  | @ -57,7 +83,7 @@ public class DynamicRateLimiter implements RateLimiter { | |||
|     final RateLimiterConfig cfg = configResolver.get(); | ||||
|     return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) | ||||
|         ? p | ||||
|         : Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster)) | ||||
|         : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock)) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -1,99 +0,0 @@ | |||
| /* | ||||
|  * Copyright 2013-2020 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import com.fasterxml.jackson.annotation.JsonProperty; | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import java.time.Duration; | ||||
| 
 | ||||
| public class LeakyBucket { | ||||
| 
 | ||||
|   private final int    bucketSize; | ||||
|   private final double leakRatePerMillis; | ||||
| 
 | ||||
|   private int spaceRemaining; | ||||
|   private long lastUpdateTimeMillis; | ||||
| 
 | ||||
|   public LeakyBucket(int bucketSize, double leakRatePerMillis) { | ||||
|     this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis()); | ||||
|   } | ||||
| 
 | ||||
|   private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) { | ||||
|     this.bucketSize           = bucketSize; | ||||
|     this.leakRatePerMillis    = leakRatePerMillis; | ||||
|     this.spaceRemaining       = spaceRemaining; | ||||
|     this.lastUpdateTimeMillis = lastUpdateTimeMillis; | ||||
|   } | ||||
| 
 | ||||
|   public boolean add(int amount) { | ||||
|     this.spaceRemaining       = getUpdatedSpaceRemaining(); | ||||
|     this.lastUpdateTimeMillis = System.currentTimeMillis(); | ||||
| 
 | ||||
|     if (this.spaceRemaining >= amount) { | ||||
|       this.spaceRemaining -= amount; | ||||
|       return true; | ||||
|     } else { | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private int getUpdatedSpaceRemaining() { | ||||
|     long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis; | ||||
| 
 | ||||
|     return Math.min(this.bucketSize, | ||||
|                     (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis))); | ||||
|   } | ||||
| 
 | ||||
|   public Duration getTimeUntilSpaceAvailable(int amount) { | ||||
|     int currentSpaceRemaining = getUpdatedSpaceRemaining(); | ||||
|     if (currentSpaceRemaining >= amount) { | ||||
|       return Duration.ZERO; | ||||
|     } else if (amount > this.bucketSize) { | ||||
|       // This shouldn't happen today but if so we should bubble this to the clients somehow | ||||
|       throw new IllegalArgumentException("Requested permits exceed maximum bucket size"); | ||||
|     } else { | ||||
|       return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis)); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   public String serialize(ObjectMapper mapper) throws JsonProcessingException { | ||||
|     return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis)); | ||||
|   } | ||||
| 
 | ||||
|   public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException { | ||||
|     LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class); | ||||
| 
 | ||||
|     return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis, | ||||
|                            entity.spaceRemaining, entity.lastUpdateTimeMillis); | ||||
|   } | ||||
| 
 | ||||
|   private static class LeakyBucketEntity { | ||||
|     @JsonProperty | ||||
|     private int    bucketSize; | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     private double leakRatePerMillis; | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     private int    spaceRemaining; | ||||
| 
 | ||||
|     @JsonProperty | ||||
|     private long   lastUpdateTimeMillis; | ||||
| 
 | ||||
|     public LeakyBucketEntity() {} | ||||
| 
 | ||||
|     private LeakyBucketEntity(int bucketSize, double leakRatePerMillis, | ||||
|                               int spaceRemaining, long lastUpdateTimeMillis) | ||||
|     { | ||||
|       this.bucketSize           = bucketSize; | ||||
|       this.leakRatePerMillis    = leakRatePerMillis; | ||||
|       this.spaceRemaining       = spaceRemaining; | ||||
|       this.lastUpdateTimeMillis = lastUpdateTimeMillis; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | @ -1,58 +0,0 @@ | |||
| /* | ||||
|  * Copyright 2013 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; | ||||
| 
 | ||||
| import io.lettuce.core.SetArgs; | ||||
| import io.micrometer.core.instrument.Counter; | ||||
| import io.micrometer.core.instrument.Metrics; | ||||
| import java.time.Duration; | ||||
| import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| 
 | ||||
| public class LockingRateLimiter extends StaticRateLimiter { | ||||
| 
 | ||||
|   private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION | ||||
|       = new RateLimitExceededException(Duration.ZERO, true); | ||||
| 
 | ||||
|   private final Counter counter; | ||||
| 
 | ||||
| 
 | ||||
|   public LockingRateLimiter( | ||||
|       final String name, | ||||
|       final RateLimiterConfig config, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|     super(name, config, cacheCluster); | ||||
|     this.counter = Metrics.counter(name(getClass(), "locked"), "name", name); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public void validate(final String key, final int amount) throws RateLimitExceededException { | ||||
|     if (!acquireLock(key)) { | ||||
|       counter.increment(); | ||||
|       throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION; | ||||
|     } | ||||
| 
 | ||||
|     try { | ||||
|       super.validate(key, amount); | ||||
|     } finally { | ||||
|       releaseLock(key); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private void releaseLock(final String key) { | ||||
|     cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key))); | ||||
|   } | ||||
| 
 | ||||
|   private boolean acquireLock(final String key) { | ||||
|     return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null; | ||||
|   } | ||||
| 
 | ||||
|   private String getLockName(final String key) { | ||||
|     return "leaky_lock::" + name + "::" + key; | ||||
|   } | ||||
| } | ||||
|  | @ -6,16 +6,23 @@ | |||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import java.util.UUID; | ||||
| import java.util.concurrent.CompletionStage; | ||||
| import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; | ||||
| 
 | ||||
| public interface RateLimiter { | ||||
| 
 | ||||
|   void validate(String key, int amount) throws RateLimitExceededException; | ||||
| 
 | ||||
|   CompletionStage<Void> validateAsync(String key, int amount); | ||||
| 
 | ||||
|   boolean hasAvailablePermits(String key, int permits); | ||||
| 
 | ||||
|   CompletionStage<Boolean> hasAvailablePermitsAsync(String key, int amount); | ||||
| 
 | ||||
|   void clear(String key); | ||||
| 
 | ||||
|   CompletionStage<Void> clearAsync(String key); | ||||
| 
 | ||||
|   RateLimiterConfig config(); | ||||
| 
 | ||||
|   default void validate(final String key) throws RateLimitExceededException { | ||||
|  | @ -30,14 +37,34 @@ public interface RateLimiter { | |||
|     validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); | ||||
|   } | ||||
| 
 | ||||
|   default CompletionStage<Void> validateAsync(final String key) { | ||||
|     return validateAsync(key, 1); | ||||
|   } | ||||
| 
 | ||||
|   default CompletionStage<Void> validateAsync(final UUID accountUuid) { | ||||
|     return validateAsync(accountUuid.toString()); | ||||
|   } | ||||
| 
 | ||||
|   default CompletionStage<Void> validateAsync(final UUID srcAccountUuid, final UUID dstAccountUuid) { | ||||
|     return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); | ||||
|   } | ||||
| 
 | ||||
|   default boolean hasAvailablePermits(final UUID accountUuid, final int permits) { | ||||
|     return hasAvailablePermits(accountUuid.toString(), permits); | ||||
|   } | ||||
| 
 | ||||
|   default CompletionStage<Boolean> hasAvailablePermitsAsync(final UUID accountUuid, final int permits) { | ||||
|     return hasAvailablePermitsAsync(accountUuid.toString(), permits); | ||||
|   } | ||||
| 
 | ||||
|   default void clear(final UUID accountUuid) { | ||||
|     clear(accountUuid.toString()); | ||||
|   } | ||||
| 
 | ||||
|   default CompletionStage<Void> clearAsync(final UUID accountUuid) { | ||||
|     return clearAsync(accountUuid.toString()); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that | ||||
|    * {@link RateLimitExceededException#isLegacy()} returns {@code true} | ||||
|  |  | |||
|  | @ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.limits; | |||
| 
 | ||||
| 
 | ||||
| import com.google.common.annotations.VisibleForTesting; | ||||
| import java.time.Clock; | ||||
| import java.util.Map; | ||||
| import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; | ||||
| import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||
| 
 | ||||
|  | @ -103,7 +105,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> { | |||
|       final Map<String, RateLimiterConfig> configs, | ||||
|       final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|     final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster); | ||||
|     final RateLimiters rateLimiters = new RateLimiters( | ||||
|         configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC()); | ||||
|     rateLimiters.validateValuesAndConfigs(); | ||||
|     return rateLimiters; | ||||
|   } | ||||
|  | @ -112,8 +115,10 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> { | |||
|   RateLimiters( | ||||
|       final Map<String, RateLimiterConfig> configs, | ||||
|       final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|     super(For.values(), configs, dynamicConfigurationManager, cacheCluster); | ||||
|       final ClusterLuaScript validateScript, | ||||
|       final FaultTolerantRedisCluster cacheCluster, | ||||
|       final Clock clock) { | ||||
|     super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock); | ||||
|   } | ||||
| 
 | ||||
|   public RateLimiter getAllocateDeviceLimiter() { | ||||
|  |  | |||
|  | @ -5,60 +5,96 @@ | |||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import static java.util.Objects.requireNonNull; | ||||
| import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; | ||||
| import static java.util.concurrent.CompletableFuture.completedFuture; | ||||
| import static java.util.concurrent.CompletableFuture.failedFuture; | ||||
| 
 | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import io.micrometer.core.instrument.Counter; | ||||
| import io.micrometer.core.instrument.Metrics; | ||||
| import java.io.IOException; | ||||
| import java.time.Clock; | ||||
| import java.time.Duration; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.CompletionStage; | ||||
| import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; | ||||
| import org.whispersystems.textsecuregcm.metrics.MetricsUtil; | ||||
| import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.util.SystemMapper; | ||||
| import org.whispersystems.textsecuregcm.util.Util; | ||||
| 
 | ||||
| public class StaticRateLimiter implements RateLimiter { | ||||
| 
 | ||||
|   private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class); | ||||
| 
 | ||||
|   protected final String name; | ||||
| 
 | ||||
|   private final RateLimiterConfig config; | ||||
| 
 | ||||
|   protected final FaultTolerantRedisCluster cacheCluster; | ||||
| 
 | ||||
|   private final Counter counter; | ||||
| 
 | ||||
|   private final ClusterLuaScript validateScript; | ||||
| 
 | ||||
|   private final FaultTolerantRedisCluster cacheCluster; | ||||
| 
 | ||||
|   private final Clock clock; | ||||
| 
 | ||||
| 
 | ||||
|   public StaticRateLimiter( | ||||
|       final String name, | ||||
|       final RateLimiterConfig config, | ||||
|       final FaultTolerantRedisCluster cacheCluster) { | ||||
|       final ClusterLuaScript validateScript, | ||||
|       final FaultTolerantRedisCluster cacheCluster, | ||||
|       final Clock clock) { | ||||
|     this.name = requireNonNull(name); | ||||
|     this.config = requireNonNull(config); | ||||
|     this.validateScript = requireNonNull(validateScript); | ||||
|     this.cacheCluster = requireNonNull(cacheCluster); | ||||
|     this.counter = Metrics.counter(name(getClass(), "exceeded"), "name", name); | ||||
|     this.clock = requireNonNull(clock); | ||||
|     this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public void validate(final String key, final int amount) throws RateLimitExceededException { | ||||
|     final LeakyBucket bucket = getBucket(key); | ||||
|     if (bucket.add(amount)) { | ||||
|       setBucket(key, bucket); | ||||
|     } else { | ||||
|     final long deficitPermitsAmount = executeValidateScript(key, amount, true); | ||||
|     if (deficitPermitsAmount > 0) { | ||||
|       counter.increment(); | ||||
|       throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true); | ||||
|       final Duration retryAfter = Duration.ofMillis( | ||||
|           (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); | ||||
|       throw new RateLimitExceededException(retryAfter, true); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public boolean hasAvailablePermits(final String key, final int permits) { | ||||
|     return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO); | ||||
|   public CompletionStage<Void> validateAsync(final String key, final int amount) { | ||||
|     return executeValidateScriptAsync(key, amount, true) | ||||
|         .thenCompose(deficitPermitsAmount -> { | ||||
|           if (deficitPermitsAmount == 0) { | ||||
|             return completedFuture(null); | ||||
|           } | ||||
|           counter.increment(); | ||||
|           final Duration retryAfter = Duration.ofMillis( | ||||
|               (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); | ||||
|           return failedFuture(new RateLimitExceededException(retryAfter, true)); | ||||
|         }); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public boolean hasAvailablePermits(final String key, final int amount) { | ||||
|     final long deficitPermitsAmount = executeValidateScript(key, amount, false); | ||||
|     return deficitPermitsAmount == 0; | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) { | ||||
|     return executeValidateScriptAsync(key, amount, false) | ||||
|         .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public void clear(final String key) { | ||||
|     cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key))); | ||||
|     cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key))); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public CompletionStage<Void> clearAsync(final String key) { | ||||
|     return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key))) | ||||
|         .thenRun(Util.NOOP); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|  | @ -66,33 +102,31 @@ public class StaticRateLimiter implements RateLimiter { | |||
|     return config; | ||||
|   } | ||||
| 
 | ||||
|   private void setBucket(final String key, final LeakyBucket bucket) { | ||||
|     try { | ||||
|       final String serialized = bucket.serialize(SystemMapper.jsonMapper()); | ||||
|       cacheCluster.useCluster(connection -> connection.sync().setex( | ||||
|           getBucketName(key), | ||||
|           (int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000), | ||||
|           serialized)); | ||||
|     } catch (final JsonProcessingException e) { | ||||
|       throw new IllegalArgumentException(e); | ||||
|     } | ||||
|   private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { | ||||
|     final List<String> keys = List.of(bucketName(key)); | ||||
|     final List<String> arguments = List.of( | ||||
|         String.valueOf(config.bucketSize()), | ||||
|         String.valueOf(config.leakRatePerMillis()), | ||||
|         String.valueOf(clock.millis()), | ||||
|         String.valueOf(amount), | ||||
|         String.valueOf(applyChanges) | ||||
|     ); | ||||
|     return (Long) validateScript.execute(keys, arguments); | ||||
|   } | ||||
| 
 | ||||
|   private LeakyBucket getBucket(final String key) { | ||||
|     try { | ||||
|       final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key))); | ||||
| 
 | ||||
|       if (serialized != null) { | ||||
|         return LeakyBucket.fromSerialized(SystemMapper.jsonMapper(), serialized); | ||||
|       } | ||||
|     } catch (final IOException e) { | ||||
|       logger.warn("Deserialization error", e); | ||||
|     } | ||||
| 
 | ||||
|     return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis()); | ||||
|   private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) { | ||||
|     final List<String> keys = List.of(bucketName(key)); | ||||
|     final List<String> arguments = List.of( | ||||
|         String.valueOf(config.bucketSize()), | ||||
|         String.valueOf(config.leakRatePerMillis()), | ||||
|         String.valueOf(clock.millis()), | ||||
|         String.valueOf(amount), | ||||
|         String.valueOf(applyChanges) | ||||
|     ); | ||||
|     return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o); | ||||
|   } | ||||
| 
 | ||||
|   private String getBucketName(final String key) { | ||||
|   private String bucketName(final String key) { | ||||
|     return "leaky_bucket::" + name + "::" + key; | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,67 @@ | |||
| -- The script encapsulates the logic of a token bucket rate limiter. | ||||
| -- Two types of operations are supported: 'check-only' and 'use-if-available' (controlled by the 'useTokens' arg). | ||||
| -- Both operations take in rate limiter configuration parameters and the requested amount of tokens. | ||||
| -- Both operations return 0, if the rate limiter has enough tokens to cover the requested amount, | ||||
| -- and the deficit amount otherwise. | ||||
| -- However, 'check-only' operation doesn't modify the bucket, while 'use-if-available' (if successful) | ||||
| -- reduces the amount of available tokens by the requested amount. | ||||
| 
 | ||||
| local bucketId = KEYS[1] | ||||
| 
 | ||||
| local bucketSize = tonumber(ARGV[1]) | ||||
| local refillRatePerMillis = tonumber(ARGV[2]) | ||||
| local currentTimeMillis = tonumber(ARGV[3]) | ||||
| local requestedAmount = tonumber(ARGV[4]) | ||||
| local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true" | ||||
| 
 | ||||
| local tokenBucketJson = redis.call("GET", bucketId) | ||||
| local tokenBucket | ||||
| local changesMade = false | ||||
| 
 | ||||
| if tokenBucketJson then | ||||
|     tokenBucket = cjson.decode(tokenBucketJson) | ||||
| else | ||||
|     tokenBucket = { | ||||
|         ["bucketSize"] = bucketSize, | ||||
|         ["leakRatePerMillis"] = refillRatePerMillis, | ||||
|         ["spaceRemaining"] = bucketSize, | ||||
|         ["lastUpdateTimeMillis"] = currentTimeMillis | ||||
|     } | ||||
| end | ||||
| 
 | ||||
| -- this can happen if rate limiter configuration has changed while the key is still in Redis | ||||
| if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then | ||||
|     tokenBucket["bucketSize"] = bucketSize | ||||
|     tokenBucket["leakRatePerMillis"] = refillRatePerMillis | ||||
|     changesMade = true | ||||
| end | ||||
| 
 | ||||
| local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"] | ||||
| local availableAmount = math.min( | ||||
|     tokenBucket["bucketSize"], | ||||
|     math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"])) | ||||
| ) | ||||
| 
 | ||||
| if availableAmount >= requestedAmount then | ||||
|     if useTokens then | ||||
|         tokenBucket["spaceRemaining"] = availableAmount - requestedAmount | ||||
|         tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis | ||||
|         changesMade = true | ||||
|     end | ||||
|     if changesMade then | ||||
|         local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"] | ||||
|         -- Storing a 'full' bucket is equivalent of not storing any state at all | ||||
|         -- (in which case a bucket will be just initialized from the input configs as a 'full' one). | ||||
|         -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish) | ||||
|         -- or we just delete the key if the bucket is full. | ||||
|         if tokensUsed > 0 then | ||||
|             local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"]) | ||||
|             redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis) | ||||
|         else | ||||
|             redis.call("DEL", bucketId) | ||||
|         end | ||||
|     end | ||||
|     return 0 | ||||
| else | ||||
|     return requestedAmount - availableAmount | ||||
| end | ||||
|  | @ -1,85 +0,0 @@ | |||
| /* | ||||
|  * Copyright 2013 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| import java.io.IOException; | ||||
| import java.time.Duration; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.whispersystems.textsecuregcm.util.SystemMapper; | ||||
| 
 | ||||
| class LeakyBucketTest { | ||||
| 
 | ||||
|   @Test | ||||
|   void testFull() { | ||||
|     LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0); | ||||
| 
 | ||||
|     assertTrue(leakyBucket.add(1)); | ||||
|     assertTrue(leakyBucket.add(1)); | ||||
|     assertFalse(leakyBucket.add(1)); | ||||
| 
 | ||||
|     leakyBucket = new LeakyBucket(2, 1.0 / 2.0); | ||||
| 
 | ||||
|     assertTrue(leakyBucket.add(2)); | ||||
|     assertFalse(leakyBucket.add(1)); | ||||
|     assertFalse(leakyBucket.add(2)); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   void testLapseRate() throws IOException { | ||||
|     ObjectMapper mapper     = SystemMapper.jsonMapper(); | ||||
|     String       serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}"; | ||||
| 
 | ||||
|     LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); | ||||
|     assertTrue(leakyBucket.add(1)); | ||||
| 
 | ||||
|     String      serializedAgain  = leakyBucket.serialize(mapper); | ||||
|     LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain); | ||||
| 
 | ||||
|     assertFalse(leakyBucketAgain.add(1)); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   void testLapseShort() throws Exception { | ||||
|     ObjectMapper mapper     = new ObjectMapper(); | ||||
|     String       serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; | ||||
| 
 | ||||
|     LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); | ||||
|     assertFalse(leakyBucket.add(1)); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   void testGetTimeUntilSpaceAvailable() throws Exception { | ||||
|     ObjectMapper mapper = new ObjectMapper(); | ||||
| 
 | ||||
|     { | ||||
|       String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; | ||||
| 
 | ||||
|       LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); | ||||
| 
 | ||||
|       assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1)); | ||||
|       assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000)); | ||||
|     } | ||||
| 
 | ||||
|     { | ||||
|       String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; | ||||
| 
 | ||||
|       LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); | ||||
| 
 | ||||
|       Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1); | ||||
| 
 | ||||
|       // TODO Refactor LeakyBucket to be more test-friendly and accept a Clock | ||||
|       assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0); | ||||
|       assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | @ -0,0 +1,147 @@ | |||
| /* | ||||
|  * Copyright 2023 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.limits; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| import static org.mockito.Mockito.mock; | ||||
| import static org.mockito.Mockito.when; | ||||
| 
 | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import io.lettuce.core.ScriptOutputType; | ||||
| import java.time.Clock; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.Optional; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.extension.RegisterExtension; | ||||
| import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; | ||||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||
| import org.whispersystems.textsecuregcm.util.MockUtils; | ||||
| import org.whispersystems.textsecuregcm.util.MutableClock; | ||||
| import org.whispersystems.textsecuregcm.util.SystemMapper; | ||||
| import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox; | ||||
| import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler; | ||||
| 
 | ||||
| public class RateLimitersLuaScriptTest { | ||||
| 
 | ||||
|   @RegisterExtension | ||||
|   private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); | ||||
| 
 | ||||
|   private final DynamicConfiguration configuration = mock(DynamicConfiguration.class); | ||||
| 
 | ||||
|   private final MutableClock clock = MockUtils.mutableClock(0); | ||||
| 
 | ||||
|   private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource( | ||||
|       "lua/validate_rate_limit.lua", | ||||
|       ScriptOutputType.INTEGER); | ||||
| 
 | ||||
|   private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock); | ||||
| 
 | ||||
|   private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig = | ||||
|       MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration)); | ||||
| 
 | ||||
|   @Test | ||||
|   public void testWithEmbeddedRedis() throws Exception { | ||||
|     final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; | ||||
|     final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); | ||||
|     final RateLimiters limiters = new RateLimiters( | ||||
|         Map.of(descriptor.id(), new RateLimiterConfig(60, 60)), | ||||
|         dynamicConfig, | ||||
|         RateLimiters.defaultScript(redisCluster), | ||||
|         redisCluster, | ||||
|         Clock.systemUTC()); | ||||
| 
 | ||||
|     final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); | ||||
|     rateLimiter.validate("test", 25); | ||||
|     rateLimiter.validate("test", 25); | ||||
|     assertThrows(Exception.class, () -> rateLimiter.validate("test", 25)); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void testLuaBucketConfigurationUpdates() throws Exception { | ||||
|     final String key = "key1"; | ||||
|     clock.setTimeMillis(0); | ||||
|     long result = (long) sandbox.execute( | ||||
|         List.of(key), | ||||
|         scriptArgs(1000, 1, 1, true), | ||||
|         redisCommandsHandler | ||||
|     ); | ||||
|     assertEquals(0L, result); | ||||
|     assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize); | ||||
| 
 | ||||
|     // now making a check-only call, but changing the bucket size | ||||
|     result = (long) sandbox.execute( | ||||
|         List.of(key), | ||||
|         scriptArgs(2000, 1, 1, false), | ||||
|         redisCommandsHandler | ||||
|     ); | ||||
|     assertEquals(0L, result); | ||||
|     assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void testLuaUpdatesTokenBucket() throws Exception { | ||||
|     final String key = "key1"; | ||||
|     clock.setTimeMillis(0); | ||||
|     long result = (long) sandbox.execute( | ||||
|         List.of(key), | ||||
|         scriptArgs(1000, 1, 200, true), | ||||
|         redisCommandsHandler | ||||
|     ); | ||||
|     assertEquals(0L, result); | ||||
|     assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining); | ||||
| 
 | ||||
|     // 50 tokens replenished, acquiring 100 more, should end up with 750 available | ||||
|     clock.setTimeMillis(50); | ||||
|     result = (long) sandbox.execute( | ||||
|         List.of(key), | ||||
|         scriptArgs(1000, 1, 100, true), | ||||
|         redisCommandsHandler | ||||
|     ); | ||||
|     assertEquals(0L, result); | ||||
|     assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); | ||||
| 
 | ||||
|     // now checking without an update, should not affect the count | ||||
|     result = (long) sandbox.execute( | ||||
|         List.of(key), | ||||
|         scriptArgs(1000, 1, 100, false), | ||||
|         redisCommandsHandler | ||||
|     ); | ||||
|     assertEquals(0L, result); | ||||
|     assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); | ||||
|   } | ||||
| 
 | ||||
|   private Optional<TokenBucket> decodeBucket(final String key) { | ||||
|     try { | ||||
|       final String json = redisCommandsHandler.get(key); | ||||
|       return json == null | ||||
|           ? Optional.empty() | ||||
|           : Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class)); | ||||
|     } catch (JsonProcessingException e) { | ||||
|       throw new RuntimeException(e); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private List<String> scriptArgs( | ||||
|       final long bucketSize, | ||||
|       final long ratePerMillis, | ||||
|       final long requestedAmount, | ||||
|       final boolean useTokens) { | ||||
|     return List.of( | ||||
|         String.valueOf(bucketSize), | ||||
|         String.valueOf(ratePerMillis), | ||||
|         String.valueOf(clock.millis()), | ||||
|         String.valueOf(requestedAmount), | ||||
|         String.valueOf(useTokens) | ||||
|     ); | ||||
|   } | ||||
| 
 | ||||
|   private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) { | ||||
|   } | ||||
| } | ||||
|  | @ -18,9 +18,11 @@ import javax.validation.Valid; | |||
| import javax.validation.constraints.NotNull; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; | ||||
| import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; | ||||
| import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; | ||||
| import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; | ||||
| import org.whispersystems.textsecuregcm.util.MockUtils; | ||||
| import org.whispersystems.textsecuregcm.util.MutableClock; | ||||
| 
 | ||||
| @SuppressWarnings("unchecked") | ||||
| public class RateLimitersTest { | ||||
|  | @ -30,8 +32,12 @@ public class RateLimitersTest { | |||
|   private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig = | ||||
|       MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration)); | ||||
| 
 | ||||
|   private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class); | ||||
| 
 | ||||
|   private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class); | ||||
| 
 | ||||
|   private final MutableClock clock = MockUtils.mutableClock(0); | ||||
| 
 | ||||
|   private static final String BAD_YAML = """ | ||||
|       limits: | ||||
|         smsVoicePrefix: | ||||
|  | @ -59,12 +65,12 @@ public class RateLimitersTest { | |||
|   public void testValidateConfigs() throws Exception { | ||||
|     assertThrows(IllegalArgumentException.class, () -> { | ||||
|       final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow(); | ||||
|       final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster); | ||||
|       final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); | ||||
|       rateLimiters.validateValuesAndConfigs(); | ||||
|     }); | ||||
| 
 | ||||
|     final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow(); | ||||
|     final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster); | ||||
|     final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); | ||||
|     rateLimiters.validateValuesAndConfigs(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -79,18 +85,22 @@ public class RateLimitersTest { | |||
|         new TestDescriptor[] { td1, td2, td3, tdDup }, | ||||
|         Collections.emptyMap(), | ||||
|         dynamicConfig, | ||||
|         redisCluster) {}); | ||||
|         validateScript, | ||||
|         redisCluster, | ||||
|         clock) {}); | ||||
| 
 | ||||
|     new BaseRateLimiters<>( | ||||
|         new TestDescriptor[] { td1, td2, td3 }, | ||||
|         Collections.emptyMap(), | ||||
|         dynamicConfig, | ||||
|         redisCluster) {}; | ||||
|         validateScript, | ||||
|         redisCluster, | ||||
|         clock) {}; | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   void testUnchangingConfiguration() { | ||||
|     final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster); | ||||
|     final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); | ||||
|     final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); | ||||
|     final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig(); | ||||
|     assertEquals(expected, limiter.config()); | ||||
|  | @ -109,7 +119,7 @@ public class RateLimitersTest { | |||
| 
 | ||||
|     when(configuration.getLimits()).thenReturn(limitsConfigMap); | ||||
| 
 | ||||
|     final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster); | ||||
|     final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); | ||||
|     final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); | ||||
| 
 | ||||
|     limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig); | ||||
|  | @ -137,7 +147,7 @@ public class RateLimitersTest { | |||
| 
 | ||||
|     when(configuration.getLimits()).thenReturn(mapForDynamic); | ||||
| 
 | ||||
|     final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster); | ||||
|     final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock); | ||||
|     final RateLimiter limiter = rateLimiters.forDescriptor(descriptor); | ||||
| 
 | ||||
|     // test only default is present | ||||
|  |  | |||
|  | @ -0,0 +1,54 @@ | |||
| /* | ||||
|  * Copyright 2023 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.util.redis; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * This class is to be extended with implementations of Redis commands as needed. | ||||
|  */ | ||||
| public class BaseRedisCommandsHandler implements RedisCommandsHandler { | ||||
| 
 | ||||
|   @Override | ||||
|   public Object redisCommand(final String command, final List<Object> args) { | ||||
|     return switch (command) { | ||||
|       case "SET" -> { | ||||
|         assertTrue(args.size() > 2); | ||||
|         yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2)); | ||||
|       } | ||||
|       case "GET" -> { | ||||
|         assertEquals(1, args.size()); | ||||
|         yield get(args.get(0).toString()); | ||||
|       } | ||||
|       case "DEL" -> { | ||||
|         assertTrue(args.size() > 1); | ||||
|         yield del(args.get(0).toString()); | ||||
|       } | ||||
|       default -> other(command, args); | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   public Object set(final String key, final String value, final List<Object> tail) { | ||||
|     return "OK"; | ||||
|   } | ||||
| 
 | ||||
|   public String get(final String key) { | ||||
|     return null; | ||||
| 
 | ||||
|   } | ||||
| 
 | ||||
|   public int del(final String key) { | ||||
|     return 0; | ||||
|   } | ||||
| 
 | ||||
|   public Object other(final String command, final List<Object> args) { | ||||
|     return "OK"; | ||||
|   } | ||||
| } | ||||
|  | @ -0,0 +1,14 @@ | |||
| /* | ||||
|  * Copyright 2023 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.util.redis; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| @FunctionalInterface | ||||
| public interface RedisCommandsHandler { | ||||
| 
 | ||||
|   Object redisCommand(String command, List<Object> args); | ||||
| } | ||||
|  | @ -0,0 +1,167 @@ | |||
| /* | ||||
|  * Copyright 2023 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.util.redis; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.google.common.io.Resources; | ||||
| import io.lettuce.core.ScriptOutputType; | ||||
| import java.io.IOException; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import org.whispersystems.textsecuregcm.util.SystemMapper; | ||||
| import party.iroiro.luajava.Lua; | ||||
| import party.iroiro.luajava.lua51.Lua51; | ||||
| import party.iroiro.luajava.value.ImmutableLuaValue; | ||||
| 
 | ||||
| public class RedisLuaScriptSandbox { | ||||
| 
 | ||||
|   private static final String PREFIX = """ | ||||
|       function redis_call(...) | ||||
|         -- variable name needs to match the one used in the `L.setGlobal()` call | ||||
|         -- method name needs to match method name of the Java class  | ||||
|         return proxy:redisCall(arg) | ||||
|       end | ||||
|        | ||||
|       function json_encode(obj) | ||||
|         return mapper:encode(obj) | ||||
|       end | ||||
|        | ||||
|       function json_decode(json) | ||||
|         return java.luaify(mapper:decode(json)) | ||||
|       end | ||||
|        | ||||
|       local redis = { call = redis_call } | ||||
|       local cjson = { encode = json_encode, decode = json_decode } | ||||
|        | ||||
|       """; | ||||
| 
 | ||||
|   private final String luaScript; | ||||
| 
 | ||||
|   private final ScriptOutputType scriptOutputType; | ||||
| 
 | ||||
| 
 | ||||
|   public static RedisLuaScriptSandbox fromResource( | ||||
|       final String resource, | ||||
|       final ScriptOutputType scriptOutputType) { | ||||
|     try { | ||||
|       final String src = Resources.toString(Resources.getResource(resource), StandardCharsets.UTF_8); | ||||
|       return new RedisLuaScriptSandbox(src, scriptOutputType); | ||||
|     } catch (IOException e) { | ||||
|       throw new RuntimeException(e); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   public RedisLuaScriptSandbox(final String luaScript, final ScriptOutputType scriptOutputType) { | ||||
|     this.luaScript = luaScript; | ||||
|     this.scriptOutputType = scriptOutputType; | ||||
|   } | ||||
| 
 | ||||
|   public Object execute( | ||||
|       final List<String> keys, | ||||
|       final List<String> args, | ||||
|       final RedisCommandsHandler redisCallsHandler) { | ||||
| 
 | ||||
|     try (final Lua lua = new Lua51()) { | ||||
|       lua.openLibraries(); | ||||
|       final RedisLuaProxy proxy = new RedisLuaProxy(redisCallsHandler); | ||||
|       lua.push(MapperLuaProxy.INSTANCE, Lua.Conversion.FULL); | ||||
|       lua.setGlobal("mapper"); | ||||
|       lua.push(proxy, Lua.Conversion.FULL); | ||||
|       lua.setGlobal("proxy"); | ||||
|       lua.push(keys, Lua.Conversion.FULL); | ||||
|       lua.setGlobal("KEYS"); | ||||
|       lua.push(args, Lua.Conversion.FULL); | ||||
|       lua.setGlobal("ARGV"); | ||||
|       final Lua.LuaError executionResult = lua.run(PREFIX + luaScript); | ||||
|       assertEquals("OK", executionResult.name(), "Runtime error during Lua script execution"); | ||||
|       return adaptOutputResult(lua.get()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   protected Object adaptOutputResult(final Object luaObject) { | ||||
|     if (luaObject instanceof ImmutableLuaValue<?> luaValue) { | ||||
|       final Object javaValue = luaValue.toJavaObject(); | ||||
|       // validate expected script output type | ||||
|       switch (scriptOutputType) { | ||||
|         case INTEGER -> assertTrue(javaValue instanceof Double); // lua number is always Double | ||||
|         case STATUS -> assertTrue(javaValue instanceof String); | ||||
|         case BOOLEAN -> assertTrue(javaValue instanceof Boolean); | ||||
|       }; | ||||
|       if (javaValue instanceof Double d) { | ||||
|         return d.longValue(); | ||||
|       } | ||||
|       if (javaValue instanceof String s) { | ||||
|         return s; | ||||
|       } | ||||
|       if (javaValue instanceof Boolean b) { | ||||
|         return b; | ||||
|       } | ||||
|       if (javaValue == null) { | ||||
|         return null; | ||||
|       } | ||||
|       throw new IllegalStateException("unexpected script result java type: " + javaValue.getClass().getName()); | ||||
|     } | ||||
|     throw new IllegalStateException("unexpected script result lua type: " + luaObject.getClass().getName()); | ||||
|   } | ||||
| 
 | ||||
|   public static <T> List<T> tail(final List<T> list, final int fromIdx) { | ||||
|     return fromIdx < list.size() ? list.subList(fromIdx, list.size()) : Collections.emptyList(); | ||||
|   } | ||||
| 
 | ||||
|   public static final class MapperLuaProxy { | ||||
| 
 | ||||
|     public static final MapperLuaProxy INSTANCE = new MapperLuaProxy(); | ||||
| 
 | ||||
|     public String encode(final Map<Object, Object> obj) { | ||||
|       try { | ||||
|         return SystemMapper.jsonMapper().writeValueAsString(obj); | ||||
|       } catch (JsonProcessingException e) { | ||||
|         throw new RuntimeException(e); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     public Map<Object, Object> decode(final Object json) { | ||||
|       try { | ||||
|         //noinspection unchecked | ||||
|         return SystemMapper.jsonMapper().readValue(json.toString(), Map.class); | ||||
|       } catch (JsonProcessingException e) { | ||||
|         throw new RuntimeException(e); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Instances of this class are passed to the Lua scripting engine | ||||
|    * and serve as a stubs for the calls to `redis.call()`. | ||||
|    * | ||||
|    * @see #PREFIX | ||||
|    */ | ||||
|   public static final class RedisLuaProxy { | ||||
| 
 | ||||
|     private final RedisCommandsHandler handler; | ||||
| 
 | ||||
|     public RedisLuaProxy(final RedisCommandsHandler handler) { | ||||
|       this.handler = handler; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Method name needs to match the one from the {@link #PREFIX} code. | ||||
|      * The method is getting called from the Lua scripting engine. | ||||
|      */ | ||||
|     @SuppressWarnings("unused") | ||||
|     public Object redisCall(final List<Object> args) { | ||||
|       assertFalse(args.isEmpty(), "`redis.call()` in Lua script invoked without arguments"); | ||||
|       assertTrue(args.get(0) instanceof String, "first argument to `redis.call()` must be of type `String`"); | ||||
|       return handler.redisCommand((String) args.get(0), tail(args, 1)); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | @ -0,0 +1,73 @@ | |||
| /* | ||||
|  * Copyright 2023 Signal Messenger, LLC | ||||
|  * SPDX-License-Identifier: AGPL-3.0-only | ||||
|  */ | ||||
| 
 | ||||
| package org.whispersystems.textsecuregcm.util.redis; | ||||
| 
 | ||||
| import java.time.Clock; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.concurrent.ConcurrentHashMap; | ||||
| 
 | ||||
| public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { | ||||
| 
 | ||||
|   public record Entry(String value, long expirationEpochMillis) { | ||||
|   } | ||||
| 
 | ||||
|   private final Map<String, Entry> cache = new ConcurrentHashMap<>(); | ||||
| 
 | ||||
|   private final Clock clock; | ||||
| 
 | ||||
| 
 | ||||
|   public SimpleCacheCommandsHandler(final Clock clock) { | ||||
|     this.clock = clock; | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public Object set(final String key, final String value, final List<Object> tail) { | ||||
|     cache.put(key, new Entry(value, resolveExpirationEpochMillis(tail))); | ||||
|     return "OK"; | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public String get(final String key) { | ||||
|     final Entry entry = cache.get(key); | ||||
|     if (entry == null) { | ||||
|       return null; | ||||
|     } | ||||
|     if (entry.expirationEpochMillis() < clock.millis()) { | ||||
|       del(key); | ||||
|       return null; | ||||
|     } | ||||
|     return entry.value(); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public int del(final String key) { | ||||
|     return cache.remove(key) != null ? 1 : 0; | ||||
|   } | ||||
| 
 | ||||
|   protected long resolveExpirationEpochMillis(final List<Object> args) { | ||||
|     for (int i = 0; i < args.size() - 1; i++) { | ||||
|       final long currentTimeMillis = clock.millis(); | ||||
|       final String param = args.get(i).toString(); | ||||
|       final String value = args.get(i + 1).toString(); | ||||
|       switch (param) { | ||||
|         case "EX" -> { | ||||
|           return currentTimeMillis + Double.valueOf(value).longValue() * 1000; | ||||
|         } | ||||
|         case "PX" -> { | ||||
|           return currentTimeMillis + Double.valueOf(value).longValue(); | ||||
|         } | ||||
|         case "EXAT" -> { | ||||
|           return Double.valueOf(value).longValue() * 1000; | ||||
|         } | ||||
|         case "PXAT" -> { | ||||
|           return Double.valueOf(value).longValue(); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     return Long.MAX_VALUE; | ||||
|   } | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	 Sergey Skrobotov
						Sergey Skrobotov