diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java index 3bb70ac00..6bc4015e1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java @@ -23,6 +23,7 @@ import org.hibernate.validator.constraints.URL; import javax.validation.Valid; import javax.validation.constraints.NotNull; +import java.time.Duration; import java.util.List; public class RedisConfiguration { @@ -35,6 +36,10 @@ public class RedisConfiguration { @NotNull private List replicaUrls; + @JsonProperty + @NotNull + private Duration timeout = Duration.ofSeconds(10); + @JsonProperty @NotNull @Valid @@ -48,6 +53,10 @@ public class RedisConfiguration { return replicaUrls; } + public Duration getTimeout() { + return timeout; + } + public CircuitBreakerConfiguration getCircuitBreakerConfiguration() { return circuitBreaker; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java new file mode 100644 index 000000000..b02dc4ccd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClient.java @@ -0,0 +1,114 @@ +package org.whispersystems.textsecuregcm.redis; + +import com.codahale.metrics.SharedMetricRegistries; +import com.codahale.metrics.Timer; +import com.google.common.annotations.VisibleForTesting; +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.lettuce.core.RedisClient; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.RedisConfiguration; +import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; +import org.whispersystems.textsecuregcm.util.Constants; + +import java.time.Duration; +import java.util.function.Consumer; +import java.util.function.Function; + +import static com.codahale.metrics.MetricRegistry.name; + +public class FaultTolerantRedisClient { + + private final RedisClient client; + + private final StatefulRedisConnection stringConnection; + private final StatefulRedisConnection binaryConnection; + private final CircuitBreaker circuitBreaker; + + private final Timer executeTimer; + + private static final Logger log = LoggerFactory.getLogger(FaultTolerantRedisClient.class); + + public FaultTolerantRedisClient(final String name, final RedisConfiguration redisConfiguration) { + this(name, RedisClient.create(redisConfiguration.getUrl()), redisConfiguration.getTimeout(), redisConfiguration.getCircuitBreakerConfiguration()); + } + + @VisibleForTesting + FaultTolerantRedisClient(final String name, final RedisClient redisClient, final Duration commandTimeout, final CircuitBreakerConfiguration circuitBreakerConfiguration) { + this.client = redisClient; + this.client.setDefaultTimeout(commandTimeout); + + this.stringConnection = client.connect(); + this.binaryConnection = client.connect(ByteArrayCodec.INSTANCE); + + this.circuitBreaker = CircuitBreaker.of(name, circuitBreakerConfiguration.toCircuitBreakerConfig()); + + CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME), + circuitBreaker, + FaultTolerantRedisCluster.class); + + this.executeTimer = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME).timer(name(getClass(), name, "execute")); + } + + @VisibleForTesting + void shutdown() { + stringConnection.close(); + client.shutdown(); + } + + public void useClient(final Consumer> consumer) { + useConnection(stringConnection, consumer); + } + + public T withClient(final Function, T> function) { + return withConnection(stringConnection, function); + } + + public void useBinaryClient(final Consumer> consumer) { + useConnection(binaryConnection, consumer); + } + + public T withBinaryClient(final Function, T> function) { + return withConnection(binaryConnection, function); + } + + private void useConnection(final StatefulRedisConnection connection, final Consumer> consumer) { + try { + circuitBreaker.executeCheckedRunnable(() -> { + try (final Timer.Context ignored = executeTimer.time()) { + consumer.accept(connection); + } + }); + } catch (final Throwable t) { + log.warn("Redis operation failure", t); + + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t); + } + } + } + + private T withConnection(final StatefulRedisConnection connection, final Function, T> function) { + try { + return circuitBreaker.executeCheckedSupplier(() -> { + try (final Timer.Context ignored = executeTimer.time()) { + return function.apply(connection); + } + }); + } catch (final Throwable t) { + log.warn("Redis operation failure", t); + + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t); + } + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClientTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClientTest.java new file mode 100644 index 000000000..e5fa667c2 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClientTest.java @@ -0,0 +1,58 @@ +package org.whispersystems.textsecuregcm.redis; + +import io.github.resilience4j.circuitbreaker.CallNotPermittedException; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisException; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import org.junit.Before; +import org.junit.Test; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; + +import java.time.Duration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FaultTolerantRedisClientTest { + + private RedisCommands commands; + private FaultTolerantRedisClient faultTolerantRedisClient; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + final RedisClient redisClient = mock(RedisClient.class); + final StatefulRedisConnection clusterConnection = mock(StatefulRedisConnection.class); + + commands = mock(RedisCommands.class); + + when(redisClient.connect()).thenReturn(clusterConnection); + when(clusterConnection.sync()).thenReturn(commands); + + final CircuitBreakerConfiguration breakerConfiguration = new CircuitBreakerConfiguration(); + breakerConfiguration.setFailureRateThreshold(100); + breakerConfiguration.setRingBufferSizeInClosedState(1); + breakerConfiguration.setWaitDurationInOpenStateInSeconds(Integer.MAX_VALUE); + + faultTolerantRedisClient = new FaultTolerantRedisClient("test", redisClient, Duration.ofSeconds(2), breakerConfiguration); + } + + @Test + public void testBreaker() { + when(commands.get(anyString())) + .thenReturn("value") + .thenThrow(new io.lettuce.core.RedisException("Badness has ensued.")); + + assertEquals("value", faultTolerantRedisClient.withClient(connection -> connection.sync().get("key"))); + + assertThrows(RedisException.class, + () -> faultTolerantRedisClient.withClient(connection -> connection.sync().get("OH NO"))); + + assertThrows(CallNotPermittedException.class, + () -> faultTolerantRedisClient.withClient(connection -> connection.sync().get("OH NO"))); + } +}