diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index ed2ad5ede..d4f9a7bcb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -176,8 +176,10 @@ import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.PushLatencyManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; +import org.whispersystems.textsecuregcm.redis.ClusterFaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.redis.ShardFaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; @@ -412,18 +414,25 @@ public class WhisperServerService extends Application keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(100_000); Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java index 2fd70bec1..96488e158 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java @@ -41,8 +41,7 @@ public class CircuitBreakerConfiguration { @JsonProperty @NotNull - @Min(1) - 
private long waitDurationInOpenStateInSeconds = 10; + private Duration waitDurationInOpenState = Duration.ofSeconds(10); @JsonProperty private List ignoredExceptions = Collections.emptyList(); @@ -64,8 +63,8 @@ public class CircuitBreakerConfiguration { return slidingWindowMinimumNumberOfCalls; } - public long getWaitDurationInOpenStateInSeconds() { - return waitDurationInOpenStateInSeconds; + public Duration getWaitDurationInOpenState() { + return waitDurationInOpenState; } public List> getIgnoredExceptions() { @@ -101,8 +100,8 @@ public class CircuitBreakerConfiguration { } @VisibleForTesting - public void setWaitDurationInOpenStateInSeconds(int seconds) { - this.waitDurationInOpenStateInSeconds = seconds; + public void setWaitDurationInOpenState(Duration duration) { + this.waitDurationInOpenState = duration; } @VisibleForTesting @@ -115,7 +114,7 @@ public class CircuitBreakerConfiguration { .failureRateThreshold(getFailureRateThreshold()) .ignoreExceptions(getIgnoredExceptions().toArray(new Class[0])) .permittedNumberOfCallsInHalfOpenState(getPermittedNumberOfCallsInHalfOpenState()) - .waitDurationInOpenState(Duration.ofSeconds(getWaitDurationInOpenStateInSeconds())) + .waitDurationInOpenState(getWaitDurationInOpenState()) .slidingWindow(getSlidingWindowSize(), getSlidingWindowMinimumNumberOfCalls(), CircuitBreakerConfig.SlidingWindowType.COUNT_BASED) .build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java index cac822e86..110385428 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java @@ -24,6 +24,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.IntStream; +import 
io.micrometer.core.instrument.Tags; import org.glassfish.jersey.SslConfigurator; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; @@ -56,7 +57,7 @@ public class FaultTolerantHttpClient { this.defaultRequestTimeout = defaultRequestTimeout; this.breaker = CircuitBreaker.of(name + "-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); - CircuitBreakerUtil.registerMetrics(breaker, FaultTolerantHttpClient.class); + CircuitBreakerUtil.registerMetrics(breaker, FaultTolerantHttpClient.class, Tags.empty()); if (retryConfiguration != null) { if (this.retryExecutor == null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java index 5fe177aaa..3db915a6c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java @@ -78,7 +78,7 @@ public class ProvisioningManager extends RedisPubSubAdapter impl this.circuitBreaker = CircuitBreaker.of("pubsub-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); - CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class); + CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class, Tags.empty()); Metrics.gaugeMapSize(ACTIVE_LISTENERS_GAUGE_NAME, Tags.empty(), listenersByProvisioningAddress); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantPubSubConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantPubSubConnection.java new file mode 100644 index 000000000..b2e62da19 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantPubSubConnection.java @@ -0,0 +1,106 @@ +/* + * Copyright 
2013-2020 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.retry.Retry; +import io.lettuce.core.RedisException; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.Timer; +import java.util.function.Consumer; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; +import reactor.core.scheduler.Scheduler; + +public class ClusterFaultTolerantPubSubConnection implements FaultTolerantPubSubConnection { + + private static final Logger logger = LoggerFactory.getLogger(ClusterFaultTolerantPubSubConnection.class); + + + private final String name; + private final StatefulRedisClusterPubSubConnection pubSubConnection; + + private final CircuitBreaker circuitBreaker; + private final Retry retry; + private final Retry resubscribeRetry; + private final Scheduler topologyChangedEventScheduler; + + private final Timer executeTimer; + + public ClusterFaultTolerantPubSubConnection(final String name, + final StatefulRedisClusterPubSubConnection pubSubConnection, final CircuitBreaker circuitBreaker, + final Retry retry, final Retry resubscribeRetry, final Scheduler topologyChangedEventScheduler) { + this.name = name; + this.pubSubConnection = pubSubConnection; + this.circuitBreaker = circuitBreaker; + this.retry = retry; + this.resubscribeRetry = resubscribeRetry; + this.topologyChangedEventScheduler = topologyChangedEventScheduler; + + this.pubSubConnection.setNodeMessagePropagation(true); + + this.executeTimer = 
Metrics.timer(name(getClass(), "execute"), "clusterName", name + "-pubsub"); + + CircuitBreakerUtil.registerMetrics(circuitBreaker, ClusterFaultTolerantPubSubConnection.class, Tags.empty()); + } + + @Override + public void usePubSubConnection(final Consumer> consumer) { + try { + circuitBreaker.executeCheckedRunnable( + () -> retry.executeRunnable(() -> executeTimer.record(() -> consumer.accept(pubSubConnection)))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw (RedisException) t; + } else { + throw new RedisException(t); + } + } + } + + @Override + public T withPubSubConnection(final Function, T> function) { + try { + return circuitBreaker.executeCheckedSupplier( + () -> retry.executeCallable(() -> executeTimer.record(() -> function.apply(pubSubConnection)))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw (RedisException) t; + } else { + throw new RedisException(t); + } + } + } + + + @Override + public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) { + + usePubSubConnection(connection -> connection.getResources().eventBus().get() + .filter(event -> event instanceof ClusterTopologyChangedEvent) + .subscribeOn(topologyChangedEventScheduler) + .subscribe(event -> { + logger.info("Got topology change event for {}, resubscribing all keyspace notifications", name); + + resubscribeRetry.executeRunnable(() -> { + try { + eventHandler.run(); + } catch (final RuntimeException e) { + logger.warn("Resubscribe for {} failed", name, e); + throw e; + } + }); + })); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantRedisCluster.java new file mode 100644 index 000000000..64b53f2d3 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterFaultTolerantRedisCluster.java @@ -0,0 +1,196 @@ +/* + * Copyright 2013-2020 
Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import com.google.common.annotations.VisibleForTesting; +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator; +import io.github.resilience4j.reactor.retry.RetryOperator; +import io.github.resilience4j.retry.Retry; +import io.github.resilience4j.retry.RetryConfig; +import io.lettuce.core.ClientOptions.DisconnectedBehavior; +import io.lettuce.core.RedisCommandTimeoutException; +import io.lettuce.core.RedisException; +import io.lettuce.core.TimeoutOptions; +import io.lettuce.core.cluster.ClusterClientOptions; +import io.lettuce.core.cluster.ClusterTopologyRefreshOptions; +import io.lettuce.core.cluster.RedisClusterClient; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.resource.ClientResources; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import io.micrometer.core.instrument.Tags; +import org.reactivestreams.Publisher; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; +import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +/** + * A fault-tolerant access manager for a Redis cluster. A single circuit breaker protects all cluster + * calls. 
+ */ +public class ClusterFaultTolerantRedisCluster implements FaultTolerantRedisCluster { + + private final String name; + + private final RedisClusterClient clusterClient; + + private final StatefulRedisClusterConnection stringConnection; + private final StatefulRedisClusterConnection binaryConnection; + + private final List> pubSubConnections = new ArrayList<>(); + + private final CircuitBreaker circuitBreaker; + private final Retry retry; + private final Retry topologyChangedEventRetry; + + public ClusterFaultTolerantRedisCluster(final String name, final RedisClusterConfiguration clusterConfiguration, + final ClientResources clientResources) { + this(name, + RedisClusterClient.create(clientResources, + RedisUriUtil.createRedisUriWithTimeout(clusterConfiguration.getConfigurationUri(), + clusterConfiguration.getTimeout())), + clusterConfiguration.getTimeout(), + clusterConfiguration.getCircuitBreakerConfiguration(), + clusterConfiguration.getRetryConfiguration()); + } + + @VisibleForTesting + ClusterFaultTolerantRedisCluster(final String name, final RedisClusterClient clusterClient, + final Duration commandTimeout, + final CircuitBreakerConfiguration circuitBreakerConfiguration, final RetryConfiguration retryConfiguration) { + this.name = name; + + this.clusterClient = clusterClient; + this.clusterClient.setOptions(ClusterClientOptions.builder() + .disconnectedBehavior(DisconnectedBehavior.REJECT_COMMANDS) + .validateClusterNodeMembership(false) + .topologyRefreshOptions(ClusterTopologyRefreshOptions.builder() + .enableAllAdaptiveRefreshTriggers() + .build()) + // for asynchronous commands + .timeoutOptions(TimeoutOptions.builder() + .fixedTimeout(commandTimeout) + .build()) + .publishOnScheduler(true) + .build()); + + this.stringConnection = clusterClient.connect(); + this.binaryConnection = clusterClient.connect(ByteArrayCodec.INSTANCE); + + this.circuitBreaker = CircuitBreaker.of(name + "-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); + 
this.retry = Retry.of(name + "-retry", retryConfiguration.toRetryConfigBuilder() + .retryOnException(exception -> exception instanceof RedisCommandTimeoutException).build()); + final RetryConfig topologyChangedEventRetryConfig = RetryConfig.custom() + .maxAttempts(Integer.MAX_VALUE) + .intervalFunction( + IntervalFunction.ofExponentialRandomBackoff(Duration.ofSeconds(1), 1.5, Duration.ofSeconds(30))) + .build(); + + this.topologyChangedEventRetry = Retry.of(name + "-topologyChangedRetry", topologyChangedEventRetryConfig); + + CircuitBreakerUtil.registerMetrics(circuitBreaker, FaultTolerantRedisCluster.class, Tags.empty()); + CircuitBreakerUtil.registerMetrics(retry, FaultTolerantRedisCluster.class); + } + + @Override + public void shutdown() { + stringConnection.close(); + binaryConnection.close(); + + for (final StatefulRedisClusterPubSubConnection pubSubConnection : pubSubConnections) { + pubSubConnection.close(); + } + + clusterClient.shutdown(); + } + + @Override + public String getName() { + return name; + } + + @Override + public void useCluster(final Consumer> consumer) { + useConnection(stringConnection, consumer); + } + + @Override + public T withCluster(final Function, T> function) { + return withConnection(stringConnection, function); + } + + @Override + public void useBinaryCluster(final Consumer> consumer) { + useConnection(binaryConnection, consumer); + } + + @Override + public T withBinaryCluster(final Function, T> function) { + return withConnection(binaryConnection, function); + } + + @Override + public Publisher withBinaryClusterReactive( + final Function, Publisher> function) { + return withConnectionReactive(binaryConnection, function); + } + + @Override + public void useConnection(final StatefulRedisClusterConnection connection, + final Consumer> consumer) { + try { + circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw 
(RedisException) t; + } else { + throw new RedisException(t); + } + } + } + + @Override + public T withConnection(final StatefulRedisClusterConnection connection, + final Function, T> function) { + try { + return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw (RedisException) t; + } else { + throw new RedisException(t); + } + } + } + + @Override + public Publisher withConnectionReactive(final StatefulRedisClusterConnection connection, + final Function, Publisher> function) { + + return Flux.from(function.apply(connection)) + .transformDeferred(RetryOperator.of(retry)) + .transformDeferred(CircuitBreakerOperator.of(circuitBreaker)); + } + + public FaultTolerantPubSubConnection createPubSubConnection() { + final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(); + pubSubConnections.add(pubSubConnection); + + return new ClusterFaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry, + topologyChangedEventRetry, + Schedulers.newSingle(name + "-redisPubSubEvents", true)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubConnection.java index 7660b8042..6ffef4b81 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubConnection.java @@ -1,102 +1,19 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2024 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.redis; -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - -import io.github.resilience4j.circuitbreaker.CircuitBreaker; -import 
io.github.resilience4j.retry.Retry; -import io.lettuce.core.RedisException; -import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Timer; import java.util.function.Consumer; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; -import reactor.core.scheduler.Scheduler; -public class FaultTolerantPubSubConnection { +public interface FaultTolerantPubSubConnection { - private static final Logger logger = LoggerFactory.getLogger(FaultTolerantPubSubConnection.class); + void usePubSubConnection(Consumer> consumer); + T withPubSubConnection(Function, T> function); - private final String name; - private final StatefulRedisClusterPubSubConnection pubSubConnection; - - private final CircuitBreaker circuitBreaker; - private final Retry retry; - private final Retry resubscribeRetry; - private final Scheduler topologyChangedEventScheduler; - - private final Timer executeTimer; - - public FaultTolerantPubSubConnection(final String name, - final StatefulRedisClusterPubSubConnection pubSubConnection, final CircuitBreaker circuitBreaker, - final Retry retry, final Retry resubscribeRetry, final Scheduler topologyChangedEventScheduler) { - this.name = name; - this.pubSubConnection = pubSubConnection; - this.circuitBreaker = circuitBreaker; - this.retry = retry; - this.resubscribeRetry = resubscribeRetry; - this.topologyChangedEventScheduler = topologyChangedEventScheduler; - - this.pubSubConnection.setNodeMessagePropagation(true); - - this.executeTimer = Metrics.timer(name(getClass(), "execute"), "clusterName", name + "-pubsub"); - - CircuitBreakerUtil.registerMetrics(circuitBreaker, FaultTolerantPubSubConnection.class); - } - - public void usePubSubConnection(final Consumer> consumer) { - try { - 
circuitBreaker.executeCheckedRunnable( - () -> retry.executeRunnable(() -> executeTimer.record(() -> consumer.accept(pubSubConnection)))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } - } - - public T withPubSubConnection(final Function, T> function) { - try { - return circuitBreaker.executeCheckedSupplier( - () -> retry.executeCallable(() -> executeTimer.record(() -> function.apply(pubSubConnection)))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } - } - - - public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) { - - usePubSubConnection(connection -> connection.getResources().eventBus().get() - .filter(event -> event instanceof ClusterTopologyChangedEvent) - .subscribeOn(topologyChangedEventScheduler) - .subscribe(event -> { - logger.info("Got topology change event for {}, resubscribing all keyspace notifications", name); - - resubscribeRetry.executeRunnable(() -> { - try { - eventHandler.run(); - } catch (final RuntimeException e) { - logger.warn("Resubscribe for {} failed", name, e); - throw e; - } - }); - })); - } - + void subscribeToClusterTopologyChangedEvents(Runnable eventHandler); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index f9dd2659d..ad6203f7f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -1,183 +1,40 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2024 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.redis; -import 
com.google.common.annotations.VisibleForTesting; -import io.github.resilience4j.circuitbreaker.CircuitBreaker; -import io.github.resilience4j.core.IntervalFunction; -import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator; -import io.github.resilience4j.reactor.retry.RetryOperator; -import io.github.resilience4j.retry.Retry; -import io.github.resilience4j.retry.RetryConfig; -import io.lettuce.core.ClientOptions.DisconnectedBehavior; -import io.lettuce.core.RedisCommandTimeoutException; -import io.lettuce.core.RedisException; -import io.lettuce.core.TimeoutOptions; -import io.lettuce.core.cluster.ClusterClientOptions; -import io.lettuce.core.cluster.ClusterTopologyRefreshOptions; -import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; -import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; -import io.lettuce.core.codec.ByteArrayCodec; -import io.lettuce.core.resource.ClientResources; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; import java.util.function.Consumer; import java.util.function.Function; import org.reactivestreams.Publisher; -import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; -import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; -import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; -import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -/** - * A fault-tolerant access manager for a Redis cluster. A fault-tolerant Redis cluster provides managed, - * circuit-breaker-protected access to connections. 
- */ -public class FaultTolerantRedisCluster { +public interface FaultTolerantRedisCluster { - private final String name; + void shutdown(); - private final RedisClusterClient clusterClient; + String getName(); - private final StatefulRedisClusterConnection stringConnection; - private final StatefulRedisClusterConnection binaryConnection; + void useCluster(Consumer> consumer); - private final List> pubSubConnections = new ArrayList<>(); + T withCluster(Function, T> function); - private final CircuitBreaker circuitBreaker; - private final Retry retry; - private final Retry topologyChangedEventRetry; + void useBinaryCluster(Consumer> consumer); - public FaultTolerantRedisCluster(final String name, final RedisClusterConfiguration clusterConfiguration, - final ClientResources clientResources) { - this(name, - RedisClusterClient.create(clientResources, - RedisUriUtil.createRedisUriWithTimeout(clusterConfiguration.getConfigurationUri(), - clusterConfiguration.getTimeout())), - clusterConfiguration.getTimeout(), - clusterConfiguration.getCircuitBreakerConfiguration(), - clusterConfiguration.getRetryConfiguration()); - } + T withBinaryCluster(Function, T> function); - @VisibleForTesting - FaultTolerantRedisCluster(final String name, final RedisClusterClient clusterClient, final Duration commandTimeout, - final CircuitBreakerConfiguration circuitBreakerConfiguration, final RetryConfiguration retryConfiguration) { - this.name = name; + Publisher withBinaryClusterReactive( + Function, Publisher> function); - this.clusterClient = clusterClient; - this.clusterClient.setOptions(ClusterClientOptions.builder() - .disconnectedBehavior(DisconnectedBehavior.REJECT_COMMANDS) - .validateClusterNodeMembership(false) - .topologyRefreshOptions(ClusterTopologyRefreshOptions.builder() - .enableAllAdaptiveRefreshTriggers() - .build()) - // for asynchronous commands - .timeoutOptions(TimeoutOptions.builder() - .fixedTimeout(commandTimeout) - .build()) - .publishOnScheduler(true) - .build()); 
+ void useConnection(StatefulRedisClusterConnection connection, + Consumer> consumer); - this.stringConnection = clusterClient.connect(); - this.binaryConnection = clusterClient.connect(ByteArrayCodec.INSTANCE); + T withConnection(StatefulRedisClusterConnection connection, + Function, T> function); - this.circuitBreaker = CircuitBreaker.of(name + "-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); - this.retry = Retry.of(name + "-retry", retryConfiguration.toRetryConfigBuilder() - .retryOnException(exception -> exception instanceof RedisCommandTimeoutException).build()); - final RetryConfig topologyChangedEventRetryConfig = RetryConfig.custom() - .maxAttempts(Integer.MAX_VALUE) - .intervalFunction( - IntervalFunction.ofExponentialRandomBackoff(Duration.ofSeconds(1), 1.5, Duration.ofSeconds(30))) - .build(); + Publisher withConnectionReactive(StatefulRedisClusterConnection connection, + Function, Publisher> function); - this.topologyChangedEventRetry = Retry.of(name + "-topologyChangedRetry", topologyChangedEventRetryConfig); - - CircuitBreakerUtil.registerMetrics(circuitBreaker, FaultTolerantRedisCluster.class); - CircuitBreakerUtil.registerMetrics(retry, FaultTolerantRedisCluster.class); - } - - void shutdown() { - stringConnection.close(); - binaryConnection.close(); - - for (final StatefulRedisClusterPubSubConnection pubSubConnection : pubSubConnections) { - pubSubConnection.close(); - } - - clusterClient.shutdown(); - } - - public String getName() { - return name; - } - - public void useCluster(final Consumer> consumer) { - useConnection(stringConnection, consumer); - } - - public T withCluster(final Function, T> function) { - return withConnection(stringConnection, function); - } - - public void useBinaryCluster(final Consumer> consumer) { - useConnection(binaryConnection, consumer); - } - - public T withBinaryCluster(final Function, T> function) { - return withConnection(binaryConnection, function); - } - - public Publisher 
withBinaryClusterReactive( - final Function, Publisher> function) { - return withConnectionReactive(binaryConnection, function); - } - - private void useConnection(final StatefulRedisClusterConnection connection, - final Consumer> consumer) { - try { - circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } - } - - private T withConnection(final StatefulRedisClusterConnection connection, - final Function, T> function) { - try { - return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } - } - - private Publisher withConnectionReactive(final StatefulRedisClusterConnection connection, - final Function, Publisher> function) { - - return Flux.from(function.apply(connection)) - .transformDeferred(RetryOperator.of(retry)) - .transformDeferred(CircuitBreakerOperator.of(circuitBreaker)); - } - - public FaultTolerantPubSubConnection createPubSubConnection() { - final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(); - pubSubConnections.add(pubSubConnection); - - return new FaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry, topologyChangedEventRetry, - Schedulers.newSingle(name + "-redisPubSubEvents", true)); - } + FaultTolerantPubSubConnection createPubSubConnection(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java new file mode 100644 index 000000000..f859ae216 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java 
/*
 * Copyright 2024 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.whispersystems.textsecuregcm.redis;

import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import io.lettuce.core.protocol.CommandHandler;
import io.lettuce.core.protocol.CompleteableCommand;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.resource.NettyCustomizer;
import io.micrometer.core.instrument.Tags;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.net.SocketAddress;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;

/**
 * A {@link NettyCustomizer} that inserts a circuit-breaker duplex handler into each Redis channel's
 * Netty pipeline, just before Lettuce's {@link CommandHandler}. Lettuce opens one channel per
 * cluster node, so each shard gets a dedicated breaker: a single failing shard can trip without
 * rejecting commands destined for healthy shards.
 */
public class LettuceShardCircuitBreaker implements NettyCustomizer {

  private final String clusterName;
  private final CircuitBreakerConfig circuitBreakerConfig;

  public LettuceShardCircuitBreaker(final String clusterName, final CircuitBreakerConfig circuitBreakerConfig) {
    this.clusterName = clusterName;
    this.circuitBreakerConfig = circuitBreakerConfig;
  }

  @Override
  public void afterChannelInitialized(final Channel channel) {

    final ChannelCircuitBreakerHandler channelCircuitBreakerHandler =
        new ChannelCircuitBreakerHandler(clusterName, circuitBreakerConfig);

    // Lettuce does not give its CommandHandler a well-known pipeline name, so locate it by type;
    // the breaker handler must sit immediately upstream of it to observe every outbound command.
    final String commandHandlerName = StreamSupport.stream(channel.pipeline().spliterator(), false)
        .filter(entry -> entry.getValue() instanceof CommandHandler)
        .map(Map.Entry::getKey)
        .findFirst()
        .orElseThrow();

    // A null handler name lets Netty generate a unique one.
    channel.pipeline().addBefore(commandHandlerName, null, channelCircuitBreakerHandler);
  }

  /**
   * Outbound handler that gates each {@link RedisCommand} write through a per-shard
   * {@link CircuitBreaker} and records command outcomes back into it.
   */
  static final class ChannelCircuitBreakerHandler extends ChannelDuplexHandler {

    private static final Logger logger = LoggerFactory.getLogger(ChannelCircuitBreakerHandler.class);

    private static final String SHARD_TAG_NAME = "shard";
    private static final String CLUSTER_TAG_NAME = "cluster";

    private final String clusterName;
    private final CircuitBreakerConfig circuitBreakerConfig;

    // Assigned in connect(); write() assumes connect() has run first on this channel —
    // NOTE(review): confirm Netty ordering guarantees no write() can precede connect() here.
    @VisibleForTesting
    CircuitBreaker breaker;

    public ChannelCircuitBreakerHandler(final String name, CircuitBreakerConfig circuitBreakerConfig) {
      this.clusterName = name;
      this.circuitBreakerConfig = circuitBreakerConfig;
    }

    @Override
    public void connect(final ChannelHandlerContext ctx, final SocketAddress remoteAddress,
        final SocketAddress localAddress, final ChannelPromise promise) throws Exception {
      super.connect(ctx, remoteAddress, localAddress, promise);
      // Unfortunately, the Channel's remote address is null until connect() is called, so we have to wait to initialize
      // the breaker with the remote's name.
      // There is a Channel attribute, io.lettuce.core.ConnectionBuilder.REDIS_URI, but this does not always
      // match remote address, as it is inherited from the Bootstrap attributes and not updated for the Channel connection

      // In some cases, like the default connection, the remote address includes the DNS hostname, which we want to exclude.
      final String shardAddress = StringUtils.substringAfter(remoteAddress.toString(), "/");
      breaker = CircuitBreaker.of("%s/%s-breaker".formatted(clusterName, shardAddress), circuitBreakerConfig);
      CircuitBreakerUtil.registerMetrics(breaker, getClass(),
          Tags.of(CLUSTER_TAG_NAME, clusterName, SHARD_TAG_NAME, shardAddress));
    }

    @Override
    public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
        throws Exception {

      logger.trace("Breaker state is {}", breaker.getState());

      // Note: io.lettuce.core.protocol.CommandHandler also supports batches (List/Collection),
      // but we do not use that feature, so we can just check for single commands
      //
      // There are two types of RedisCommands that are not CompleteableCommand:
      // - io.lettuce.core.protocol.Command
      // - io.lettuce.core.protocol.PristineFallbackCommand
      //
      // The former always get wrapped by one of the other command types, and the latter is only used in an edge case
      // to consume responses.
      if (msg instanceof RedisCommand<?, ?, ?> rc && rc instanceof CompleteableCommand<?> command) {
        try {
          breaker.acquirePermission();

          // state can change in acquirePermission()
          logger.trace("Breaker is permitted: {}", breaker.getState());

          final long startNanos = System.nanoTime();

          command.onComplete((ignored, throwable) -> {
            final long durationNanos = System.nanoTime() - startNanos;

            if (throwable != null) {
              breaker.onError(durationNanos, TimeUnit.NANOSECONDS, throwable);
              logger.debug("Command completed with error", throwable);
            } else {
              breaker.onSuccess(durationNanos, TimeUnit.NANOSECONDS);
            }
          });

        } catch (final CallNotPermittedException e) {
          // Breaker is open: fail the command and the write promise locally, without forwarding.
          rc.completeExceptionally(e);
          promise.tryFailure(e);
          return;
        }

      } else {
        logger.warn("Unexpected msg type: {}", msg.getClass());
      }

      super.write(ctx, msg, promise);
    }
  }
}
/*
 * Copyright 2024 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.whispersystems.textsecuregcm.redis;

import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;

import io.github.resilience4j.retry.Retry;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.util.function.Consumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.scheduler.Scheduler;

/**
 * A pub/sub connection wrapper whose fault handling relies on per-shard circuit breakers installed
 * at the channel level (see {@link LettuceShardCircuitBreaker}); this class itself only retries
 * failed operations and resubscribes after cluster topology changes. All failures surface to
 * callers as {@link RedisException}.
 *
 * @param <K> pub/sub channel name type
 * @param <V> pub/sub message type
 */
public class ShardFaultTolerantPubSubConnection<K, V> implements FaultTolerantPubSubConnection<K, V> {

  private static final Logger logger = LoggerFactory.getLogger(ShardFaultTolerantPubSubConnection.class);

  private final String name;
  private final StatefulRedisClusterPubSubConnection<K, V> pubSubConnection;

  private final Retry retry;
  private final Retry resubscribeRetry;
  private final Scheduler topologyChangedEventScheduler;

  // Times every operation executed through use/withPubSubConnection, tagged by cluster name.
  private final Timer executeTimer;

  public ShardFaultTolerantPubSubConnection(final String name,
      final StatefulRedisClusterPubSubConnection<K, V> pubSubConnection,
      final Retry retry, final Retry resubscribeRetry, final Scheduler topologyChangedEventScheduler) {
    this.name = name;
    this.pubSubConnection = pubSubConnection;
    this.retry = retry;
    this.resubscribeRetry = resubscribeRetry;
    this.topologyChangedEventScheduler = topologyChangedEventScheduler;

    // Propagate messages from all nodes, not just the one the connection was established to.
    this.pubSubConnection.setNodeMessagePropagation(true);

    this.executeTimer = Metrics.timer(name(getClass(), "execute"), "clusterName", name + "-pubsub");
  }

  @Override
  public void usePubSubConnection(final Consumer<StatefulRedisClusterPubSubConnection<K, V>> consumer) {
    try {
      retry.executeRunnable(() -> executeTimer.record(() -> consumer.accept(pubSubConnection)));
    } catch (final Throwable t) {
      if (t instanceof RedisException) {
        throw (RedisException) t;
      } else {
        throw new RedisException(t);
      }
    }
  }

  @Override
  public <T> T withPubSubConnection(final Function<StatefulRedisClusterPubSubConnection<K, V>, T> function) {
    try {
      return retry.executeCallable(() -> executeTimer.record(() -> function.apply(pubSubConnection)));
    } catch (final Throwable t) {
      if (t instanceof RedisException) {
        throw (RedisException) t;
      } else {
        throw new RedisException(t);
      }
    }
  }

  @Override
  public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) {

    usePubSubConnection(connection -> connection.getResources().eventBus().get()
        .filter(event -> event instanceof ClusterTopologyChangedEvent)
        .subscribeOn(topologyChangedEventScheduler)
        .subscribe(event -> {
          logger.info("Got topology change event for {}, resubscribing all keyspace notifications", name);

          // Retry resubscription indefinitely per resubscribeRetry's policy; failures are logged
          // and rethrown so the retry machinery sees them.
          resubscribeRetry.executeRunnable(() -> {
            try {
              eventHandler.run();
            } catch (final RuntimeException e) {
              logger.warn("Resubscribe for {} failed", name, e);
              throw e;
            }
          });
        }));
  }

}
/*
 * Copyright 2024 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.whispersystems.textsecuregcm.redis;

import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.reactor.retry.RetryOperator;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.RedisCommandTimeoutException;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
import io.lettuce.core.TimeoutOptions;
import io.lettuce.core.cluster.ClusterClientOptions;
import io.lettuce.core.cluster.ClusterTopologyRefreshOptions;
import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.resource.ClientResources;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

/**
 * A fault-tolerant access manager for a Redis cluster. Each shard in the cluster has a dedicated circuit breaker,
 * installed into the Netty pipeline of each channel; this class layers retries (on command timeouts) on top.
 *
 * @see LettuceShardCircuitBreaker
 */
public class ShardFaultTolerantRedisCluster implements FaultTolerantRedisCluster {

  private final String name;

  private final RedisClusterClient clusterClient;

  private final StatefulRedisClusterConnection<String, String> stringConnection;
  private final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection;

  // Tracked so shutdown() can close every pub/sub connection this cluster handed out.
  private final List<StatefulRedisClusterPubSubConnection<?, ?>> pubSubConnections = new ArrayList<>();

  private final Retry retry;
  private final Retry topologyChangedEventRetry;

  public ShardFaultTolerantRedisCluster(final String name, final RedisClusterConfiguration clusterConfiguration,
      final ClientResources.Builder clientResourcesBuilder) {

    this(name, clientResourcesBuilder,
        Collections.singleton(RedisUriUtil.createRedisUriWithTimeout(clusterConfiguration.getConfigurationUri(),
            clusterConfiguration.getTimeout())),
        clusterConfiguration.getTimeout(),
        clusterConfiguration.getCircuitBreakerConfiguration(),
        clusterConfiguration.getRetryConfiguration());
  }

  ShardFaultTolerantRedisCluster(String name, final ClientResources.Builder clientResourcesBuilder,
      Iterable<RedisURI> redisUris, Duration commandTimeout, CircuitBreakerConfiguration circuitBreakerConfig,
      RetryConfiguration retryConfiguration) {

    this.name = name;

    // The shard breaker is installed as a Netty customizer so that every channel (one per shard)
    // gets its own breaker instance.
    this.clusterClient = RedisClusterClient.create(
        clientResourcesBuilder.nettyCustomizer(
                new LettuceShardCircuitBreaker(name, circuitBreakerConfig.toCircuitBreakerConfig()))
            .build(),
        redisUris);
    this.clusterClient.setOptions(ClusterClientOptions.builder()
        .disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS)
        .validateClusterNodeMembership(false)
        .topologyRefreshOptions(ClusterTopologyRefreshOptions.builder()
            .enableAllAdaptiveRefreshTriggers()
            .build())
        // for asynchronous commands
        .timeoutOptions(TimeoutOptions.builder()
            .fixedTimeout(commandTimeout)
            .build())
        .publishOnScheduler(true)
        .build());

    this.stringConnection = clusterClient.connect();
    this.binaryConnection = clusterClient.connect(ByteArrayCodec.INSTANCE);

    // Only command timeouts are retried; other failures are surfaced (or handled by the breaker).
    this.retry = Retry.of(name + "-retry", retryConfiguration.toRetryConfigBuilder()
        .retryOnException(exception -> exception instanceof RedisCommandTimeoutException).build());

    final RetryConfig topologyChangedEventRetryConfig = RetryConfig.custom()
        .maxAttempts(Integer.MAX_VALUE)
        .intervalFunction(
            IntervalFunction.ofExponentialRandomBackoff(Duration.ofSeconds(1), 1.5, Duration.ofSeconds(30)))
        .build();

    this.topologyChangedEventRetry = Retry.of(name + "-topologyChangedRetry", topologyChangedEventRetryConfig);

    CircuitBreakerUtil.registerMetrics(retry, ShardFaultTolerantRedisCluster.class);
  }

  @Override
  public void shutdown() {
    stringConnection.close();
    binaryConnection.close();

    for (final StatefulRedisClusterPubSubConnection<?, ?> pubSubConnection : pubSubConnections) {
      pubSubConnection.close();
    }

    clusterClient.shutdown();
  }

  public String getName() {
    return name;
  }

  @Override
  public void useCluster(final Consumer<StatefulRedisClusterConnection<String, String>> consumer) {
    useConnection(stringConnection, consumer);
  }

  @Override
  public <T> T withCluster(final Function<StatefulRedisClusterConnection<String, String>, T> function) {
    return withConnection(stringConnection, function);
  }

  @Override
  public void useBinaryCluster(final Consumer<StatefulRedisClusterConnection<byte[], byte[]>> consumer) {
    useConnection(binaryConnection, consumer);
  }

  @Override
  public <T> T withBinaryCluster(final Function<StatefulRedisClusterConnection<byte[], byte[]>, T> function) {
    return withConnection(binaryConnection, function);
  }

  @Override
  public <T> Publisher<T> withBinaryClusterReactive(
      final Function<StatefulRedisClusterConnection<byte[], byte[]>, Publisher<T>> function) {
    return withConnectionReactive(binaryConnection, function);
  }

  @Override
  public <K, V> void useConnection(final StatefulRedisClusterConnection<K, V> connection,
      final Consumer<StatefulRedisClusterConnection<K, V>> consumer) {
    try {
      retry.executeRunnable(() -> consumer.accept(connection));
    } catch (final Throwable t) {
      if (t instanceof RedisException) {
        throw (RedisException) t;
      } else {
        throw new RedisException(t);
      }
    }
  }

  @Override
  public <T, K, V> T withConnection(final StatefulRedisClusterConnection<K, V> connection,
      final Function<StatefulRedisClusterConnection<K, V>, T> function) {
    try {
      return retry.executeCallable(() -> function.apply(connection));
    } catch (final Throwable t) {
      if (t instanceof RedisException) {
        throw (RedisException) t;
      } else {
        throw new RedisException(t);
      }
    }
  }

  @Override
  public <T, K, V> Publisher<T> withConnectionReactive(final StatefulRedisClusterConnection<K, V> connection,
      final Function<StatefulRedisClusterConnection<K, V>, Publisher<T>> function) {

    return Flux.from(function.apply(connection))
        .transformDeferred(RetryOperator.of(retry));
  }

  @Override
  public FaultTolerantPubSubConnection<String, String> createPubSubConnection() {
    final StatefulRedisClusterPubSubConnection<String, String> pubSubConnection = clusterClient.connectPubSub();
    pubSubConnections.add(pubSubConnection);

    return new ShardFaultTolerantPubSubConnection<>(name, pubSubConnection, retry, topologyChangedEventRetry,
        Schedulers.newSingle(name + "-redisPubSubEvents", true));
  }

}
private static final String BREAKER_NAME_TAG_NAME = "breakerName"; private static final String OUTCOME_TAG_NAME = "outcome"; - public static void registerMetrics(CircuitBreaker circuitBreaker, Class clazz) { + public static void registerMetrics(CircuitBreaker circuitBreaker, Class clazz, Tags additionalTags) { final String breakerName = clazz.getSimpleName() + "/" + circuitBreaker.getName(); final Counter successCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME, - BREAKER_NAME_TAG_NAME, breakerName, - OUTCOME_TAG_NAME, "success"); + additionalTags.and( + BREAKER_NAME_TAG_NAME, breakerName, + OUTCOME_TAG_NAME, "success")); final Counter failureCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME, - BREAKER_NAME_TAG_NAME, breakerName, - OUTCOME_TAG_NAME, "failure"); + additionalTags.and( + BREAKER_NAME_TAG_NAME, breakerName, + OUTCOME_TAG_NAME, "failure")); final Counter unpermittedCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME, - BREAKER_NAME_TAG_NAME, breakerName, - OUTCOME_TAG_NAME, "unpermitted"); + additionalTags.and(BREAKER_NAME_TAG_NAME, breakerName, + OUTCOME_TAG_NAME, "unpermitted")); circuitBreaker.getEventPublisher().onSuccess(event -> { successCounter.increment(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java index 5b4d77132..5a51b4fca 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java @@ -29,6 +29,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import 
org.whispersystems.textsecuregcm.redis.ClusterFaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -107,7 +108,7 @@ public class AssignUsernameCommand extends EnvironmentCommand("test", pubSubConnection, circuitBreaker, + faultTolerantPubSubConnection = new ClusterFaultTolerantPubSubConnection<>("test", pubSubConnection, circuitBreaker, retry, resubscribeRetry, Schedulers.newSingle("test")); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java index 415a4b775..166c9d0e6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java @@ -6,9 +6,9 @@ package org.whispersystems.textsecuregcm.redis; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -60,13 +60,13 @@ class FaultTolerantRedisClusterTest { breakerConfiguration.setFailureRateThreshold(100); breakerConfiguration.setSlidingWindowSize(1); breakerConfiguration.setSlidingWindowMinimumNumberOfCalls(1); - breakerConfiguration.setWaitDurationInOpenStateInSeconds(Integer.MAX_VALUE); + breakerConfiguration.setWaitDurationInOpenState(Duration.ofSeconds(Integer.MAX_VALUE)); final RetryConfiguration 
retryConfiguration = new RetryConfiguration(); retryConfiguration.setMaxAttempts(3); retryConfiguration.setWaitDuration(0); - faultTolerantCluster = new FaultTolerantRedisCluster("test", clusterClient, Duration.ofSeconds(2), + faultTolerantCluster = new ClusterFaultTolerantRedisCluster("test", clusterClient, Duration.ofSeconds(2), breakerConfiguration, retryConfiguration); } @@ -84,7 +84,7 @@ class FaultTolerantRedisClusterTest { final RedisException redisException = assertThrows(RedisException.class, () -> faultTolerantCluster.withCluster(connection -> connection.sync().get("OH NO"))); - assertTrue(redisException.getCause() instanceof CallNotPermittedException); + assertInstanceOf(CallNotPermittedException.class, redisException.getCause()); } @Test @@ -132,7 +132,7 @@ class FaultTolerantRedisClusterTest { assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { final ExecutionException asyncException = assertThrows(ExecutionException.class, () -> cluster.withCluster(connection -> connection.async().blpop(TIMEOUT.toMillis() * 2, "key")).get()); - assertTrue(asyncException.getCause() instanceof RedisCommandTimeoutException); + assertInstanceOf(RedisCommandTimeoutException.class, asyncException.getCause()); assertThrows(RedisCommandTimeoutException.class, () -> cluster.withCluster(connection -> connection.sync().blpop(TIMEOUT.toMillis() * 2, "key"))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java new file mode 100644 index 000000000..4696aa398 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static 
/*
 * Copyright 2024 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.whispersystems.textsecuregcm.redis;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandHandler;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.Endpoint;
import io.lettuce.core.resource.ClientResources;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.StreamSupport;
import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;

class LettuceShardCircuitBreakerTest {

  private LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler channelCircuitBreakerHandler;

  @BeforeEach
  void setUp() {
    channelCircuitBreakerHandler = new LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler(
        "test",
        new CircuitBreakerConfiguration().toCircuitBreakerConfig());
  }

  /** The customizer must insert the breaker handler upstream of Lettuce's CommandHandler. */
  @Test
  void testAfterChannelInitialized() {

    final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker("test",
        new CircuitBreakerConfiguration().toCircuitBreakerConfig());

    final Channel channel = new EmbeddedChannel(
        new CommandHandler(ClientOptions.create(), ClientResources.create(), mock(Endpoint.class)));

    lettuceShardCircuitBreaker.afterChannelInitialized(channel);

    final AtomicBoolean foundCommandHandler = new AtomicBoolean(false);
    final AtomicBoolean foundChannelCircuitBreakerHandler = new AtomicBoolean(false);
    StreamSupport.stream(channel.pipeline().spliterator(), false)
        .forEach(nameAndHandler -> {
          if (nameAndHandler.getValue() instanceof CommandHandler) {
            foundCommandHandler.set(true);
          }
          if (nameAndHandler.getValue() instanceof LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler) {
            foundChannelCircuitBreakerHandler.set(true);
          }
          if (foundCommandHandler.get()) {
            assertTrue(foundChannelCircuitBreakerHandler.get(),
                "circuit breaker handler should be before the command handler");
          }
        });

    assertTrue(foundChannelCircuitBreakerHandler.get());
    assertTrue(foundCommandHandler.get());
  }

  /** connect() is what initializes the per-shard breaker from the remote address. */
  @Test
  void testHandlerConnect() throws Exception {
    channelCircuitBreakerHandler.connect(mock(ChannelHandlerContext.class), mock(SocketAddress.class),
        mock(SocketAddress.class), mock(ChannelPromise.class));

    assertNotNull(channelCircuitBreakerHandler.breaker);
  }

  /** With a closed breaker, writes are forwarded and outcomes recorded as success/error. */
  @ParameterizedTest
  @MethodSource
  void testHandlerWriteBreakerClosed(@Nullable final Throwable t) throws Exception {
    final CircuitBreaker breaker = mock(CircuitBreaker.class);
    channelCircuitBreakerHandler.breaker = breaker;

    final AsyncCommand<String, String, String> command = new AsyncCommand<>(
        new Command<>(CommandType.PING, new StatusOutput<>(StringCodec.ASCII)));
    final ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
    final ChannelPromise channelPromise = mock(ChannelPromise.class);
    channelCircuitBreakerHandler.write(channelHandlerContext, command, channelPromise);

    verify(breaker).acquirePermission();

    if (t != null) {
      command.completeExceptionally(t);
      verify(breaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), eq(t));
    } else {
      command.complete("PONG");
      verify(breaker).onSuccess(anyLong(), eq(TimeUnit.NANOSECONDS));
    }

    // write should always be forwarded when the breaker is closed
    verify(channelHandlerContext).write(command, channelPromise);
  }

  static List<Throwable> testHandlerWriteBreakerClosed() {
    final List<Throwable> errors = new ArrayList<>();
    errors.add(null);
    errors.add(new IOException("timeout"));

    return errors;
  }

  /** With an open breaker, the command and promise fail locally and nothing is forwarded. */
  @Test
  void testHandlerWriteBreakerOpen() throws Exception {
    final CircuitBreaker breaker = mock(CircuitBreaker.class);
    channelCircuitBreakerHandler.breaker = breaker;

    final CallNotPermittedException callNotPermittedException = mock(CallNotPermittedException.class);
    doThrow(callNotPermittedException).when(breaker).acquirePermission();

    @SuppressWarnings("unchecked") final AsyncCommand<String, String, String> command = mock(AsyncCommand.class);
    final ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
    final ChannelPromise channelPromise = mock(ChannelPromise.class);
    channelCircuitBreakerHandler.write(channelHandlerContext, command, channelPromise);

    verify(command).completeExceptionally(callNotPermittedException);
    verify(channelPromise).tryFailure(callNotPermittedException);

    verifyNoInteractions(channelHandlerContext);
  }

}
java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterEachCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; @@ -50,8 +49,8 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb } - public static RedisClusterExtensionBuilder builder() { - return new RedisClusterExtensionBuilder(); + public static Builder builder() { + return new Builder(); } @Override @@ -81,12 +80,9 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb @Override public void beforeEach(final ExtensionContext context) throws Exception { - final List urls = Arrays.stream(CLUSTER_NODES) - .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0))) - .toList(); - redisCluster = new FaultTolerantRedisCluster("test-cluster", - RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())), + redisCluster = new ClusterFaultTolerantRedisCluster("test-cluster", + RedisClusterClient.create(getRedisURIs()), timeout, new CircuitBreakerConfiguration(), retryConfiguration); @@ -120,6 +116,13 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb redisCluster.useCluster(connection -> connection.sync().flushall()); } + public static List getRedisURIs() { + return Arrays.stream(CLUSTER_NODES) + .map(node -> "redis://127.0.0.1:%d".formatted(node.ports().getFirst())) + .map(RedisURI::create) + .toList(); + } + public FaultTolerantRedisCluster getRedisCluster() { return redisCluster; } @@ -140,12 +143,12 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb } private static void assembleCluster(final RedisServer... 
nodes) throws InterruptedException { - try (final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)))) { + try (final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().getFirst()))) { final StatefulRedisConnection connection = meetClient.connect(); final RedisCommands commands = connection.sync(); for (int i = 1; i < nodes.length; i++) { - commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0)); + commands.clusterMeet("127.0.0.1", nodes[i].ports().getFirst()); } } @@ -155,7 +158,8 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb final int startInclusive = i * slotsPerNode; final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode; - try (final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0))); + try (final RedisClient assignSlotClient = RedisClient.create( + RedisURI.create("127.0.0.1", nodes[i].ports().getFirst())); final StatefulRedisConnection assignSlotConnection = assignSlotClient.connect()) { final int[] slots = new int[endExclusive - startInclusive]; @@ -167,7 +171,7 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb } } - try (final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + try (final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().getFirst())); final StatefulRedisConnection connection = waitClient.connect()) { // CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`. 
// According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to @@ -181,7 +185,7 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb if (tries == 20) { throw new RuntimeException( - String.format("Timeout: Redis not ready after waiting %d milliseconds", tries * sleepMillis)); + "Timeout: Redis not ready after waiting %d milliseconds".formatted(tries * sleepMillis)); } } } @@ -215,20 +219,20 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb } } - public static class RedisClusterExtensionBuilder { + public static class Builder { private Duration timeout = DEFAULT_TIMEOUT; private RetryConfiguration retryConfiguration = new RetryConfiguration(); - private RedisClusterExtensionBuilder() { + private Builder() { } - RedisClusterExtensionBuilder timeout(Duration timeout) { + Builder timeout(Duration timeout) { this.timeout = timeout; return this; } - RedisClusterExtensionBuilder retryConfiguration(RetryConfiguration retryConfiguration) { + Builder retryConfiguration(RetryConfiguration retryConfiguration) { this.retryConfiguration = retryConfiguration; return this; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantPubSubConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantPubSubConnectionTest.java new file mode 100644 index 000000000..f06b808cd --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantPubSubConnectionTest.java @@ -0,0 +1,197 @@ +/* + * Copyright 2013-2020 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.retry.Retry; +import io.github.resilience4j.retry.RetryConfig; +import io.lettuce.core.RedisCommandTimeoutException; +import io.lettuce.core.RedisException; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; +import io.lettuce.core.event.Event; +import io.lettuce.core.event.EventBus; +import io.lettuce.core.resource.ClientResources; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; +import reactor.test.publisher.TestPublisher; + +class ShardFaultTolerantPubSubConnectionTest { + + private StatefulRedisClusterPubSubConnection<String, String> pubSubConnection; + private RedisClusterPubSubCommands<String, String> pubSubCommands; + private ShardFaultTolerantPubSubConnection<String, String> faultTolerantPubSubConnection; + + + @SuppressWarnings("unchecked") + @BeforeEach + public void setUp() { + pubSubConnection = mock(StatefulRedisClusterPubSubConnection.class); + + pubSubCommands = mock(RedisClusterPubSubCommands.class); + + when(pubSubConnection.sync()).thenReturn(pubSubCommands); + + final RetryConfiguration retryConfiguration = new RetryConfiguration(); + 
retryConfiguration.setMaxAttempts(3); + retryConfiguration.setWaitDuration(10); + + final Retry retry = Retry.of("test", retryConfiguration.toRetryConfig()); + + final RetryConfig resubscribeRetryConfiguration = RetryConfig.custom() + .maxAttempts(Integer.MAX_VALUE) + .intervalFunction(IntervalFunction.ofExponentialBackoff(5)) + .build(); + final Retry resubscribeRetry = Retry.of("test-resubscribe", resubscribeRetryConfiguration); + + faultTolerantPubSubConnection = new ShardFaultTolerantPubSubConnection<>("test", pubSubConnection, + retry, resubscribeRetry, Schedulers.newSingle("test")); + } + + @Test + void testRetry() { + when(pubSubCommands.get(anyString())) + .thenThrow(new RedisCommandTimeoutException()) + .thenThrow(new RedisCommandTimeoutException()) + .thenReturn("value"); + + assertEquals("value", + faultTolerantPubSubConnection.withPubSubConnection(connection -> connection.sync().get("key"))); + + when(pubSubCommands.get(anyString())) + .thenThrow(new RedisCommandTimeoutException()) + .thenThrow(new RedisCommandTimeoutException()) + .thenThrow(new RedisCommandTimeoutException()) + .thenReturn("value"); + + assertThrows(RedisCommandTimeoutException.class, + () -> faultTolerantPubSubConnection.withPubSubConnection(connection -> connection.sync().get("key"))); + } + + @Nested + class ClusterTopologyChangedEventTest { + + private TestPublisher<Event> eventPublisher; + + private Runnable resubscribe; + + private AtomicInteger resubscribeCounter; + private CountDownLatch resubscribeFailure; + private CountDownLatch resubscribeSuccess; + + @BeforeEach + @SuppressWarnings("unchecked") + void setup() { + // ignore inherited stubbing + reset(pubSubConnection); + + eventPublisher = TestPublisher.createCold(); + + final ClientResources clientResources = mock(ClientResources.class); + when(pubSubConnection.getResources()) + .thenReturn(clientResources); + final EventBus eventBus = mock(EventBus.class); + when(clientResources.eventBus()) + .thenReturn(eventBus); + + final 
Flux<Event> eventFlux = Flux.from(eventPublisher); + when(eventBus.get()).thenReturn(eventFlux); + + resubscribeCounter = new AtomicInteger(); + + resubscribe = () -> { + try { + resubscribeCounter.incrementAndGet(); + pubSubConnection.sync().nodes((ignored) -> true); + resubscribeSuccess.countDown(); + } catch (final RuntimeException e) { + resubscribeFailure.countDown(); + throw e; + } + }; + + resubscribeSuccess = new CountDownLatch(1); + resubscribeFailure = new CountDownLatch(1); + } + + @SuppressWarnings("unchecked") + @Test + void testSubscribeToClusterTopologyChangedEvents() throws Exception { + + when(pubSubConnection.sync()) + .thenThrow(new RedisException("Cluster unavailable")); + + eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); + + assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); + + // simulate cluster recovery - no more exceptions, run the retry + reset(pubSubConnection); + clearInvocations(pubSubCommands); + when(pubSubConnection.sync()) + .thenReturn(pubSubCommands); + + assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); + + assertTrue(resubscribeCounter.get() >= 2, String.format("resubscribe called %d times", resubscribeCounter.get())); + verify(pubSubCommands).nodes(any()); + } + + @Test + @SuppressWarnings("unchecked") + void testMultipleEventsWithPendingRetries() throws Exception { + // more complicated scenario: multiple events while retries are pending + + // cluster is down + when(pubSubConnection.sync()) + .thenThrow(new RedisException("Cluster unavailable")); + + // publish multiple topology changed events + eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + eventPublisher.next(new 
ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); + + assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); + + // simulate cluster recovery - no more exceptions, run the retry + reset(pubSubConnection); + clearInvocations(pubSubCommands); + when(pubSubConnection.sync()) + .thenReturn(pubSubCommands); + + assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); + + verify(pubSubCommands, atLeastOnce()).nodes(any()); + } + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantRedisClusterTest.java new file mode 100644 index 000000000..be5714889 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/ShardFaultTolerantRedisClusterTest.java @@ -0,0 +1,495 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.github.resilience4j.circuitbreaker.CallNotPermittedException; +import io.lettuce.core.RedisCommandTimeoutException; +import io.lettuce.core.RedisException; +import io.lettuce.core.RedisURI; +import io.lettuce.core.cluster.models.partitions.ClusterPartitionParser; +import io.lettuce.core.cluster.models.partitions.Partitions; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; +import io.lettuce.core.event.EventBus; +import 
io.lettuce.core.event.EventPublisherOptions; +import io.lettuce.core.metrics.CommandLatencyCollectorOptions; +import io.lettuce.core.metrics.CommandLatencyRecorder; +import io.lettuce.core.resource.ClientResources; +import io.lettuce.core.resource.Delay; +import io.lettuce.core.resource.DnsResolver; +import io.lettuce.core.resource.EventLoopGroupProvider; +import io.lettuce.core.resource.NettyCustomizer; +import io.lettuce.core.resource.SocketAddressResolver; +import io.lettuce.core.resource.ThreadFactoryProvider; +import io.lettuce.core.tracing.Tracing; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.resolver.AddressResolverGroup; +import io.netty.util.Timer; +import io.netty.util.concurrent.EventExecutorGroup; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.Nullable; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import org.whispersystems.textsecuregcm.util.Pair; +import 
org.whispersystems.textsecuregcm.util.RedisClusterUtil; + +// ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be +// preempted by the timeout check +@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class ShardFaultTolerantRedisClusterTest { + + private static final Duration TIMEOUT = Duration.ofMillis(50); + + private static final RetryConfiguration RETRY_CONFIGURATION = new RetryConfiguration(); + + static { + RETRY_CONFIGURATION.setMaxAttempts(1); + RETRY_CONFIGURATION.setWaitDuration(50); + } + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder() + .retryConfiguration(RETRY_CONFIGURATION) + .timeout(TIMEOUT) + .build(); + + private ShardFaultTolerantRedisCluster cluster; + + private static ShardFaultTolerantRedisCluster buildCluster( + @Nullable final CircuitBreakerConfiguration circuitBreakerConfiguration, + final ClientResources.Builder clientResourcesBuilder) { + + return new ShardFaultTolerantRedisCluster("test", clientResourcesBuilder, + RedisClusterExtension.getRedisURIs(), TIMEOUT, + Optional.ofNullable(circuitBreakerConfiguration).orElseGet(CircuitBreakerConfiguration::new), + RETRY_CONFIGURATION); + } + + @AfterEach + void tearDown() { + cluster.shutdown(); + } + + @Test + void testTimeout() { + cluster = buildCluster(null, ClientResources.builder()); + + final ExecutionException asyncException = assertThrows(ExecutionException.class, + () -> cluster.withCluster(connection -> connection.async().blpop(2 * TIMEOUT.toMillis() / 1000d, "key")) + .get()); + + assertInstanceOf(RedisCommandTimeoutException.class, asyncException.getCause()); + + assertThrows(RedisCommandTimeoutException.class, + () -> cluster.withCluster(connection -> connection.sync().blpop(2 * TIMEOUT.toMillis() / 1000d, "key"))); + } + + @Test + void testTimeoutCircuitBreaker() throws Exception { + // because we’re using a single key, and blpop 
involves *Redis* also blocking, the breaker wait duration must be + // longer than the sum of the remote timeouts + final Duration breakerWaitDuration = TIMEOUT.multipliedBy(5); + + final CircuitBreakerConfiguration circuitBreakerConfig = new CircuitBreakerConfiguration(); + circuitBreakerConfig.setFailureRateThreshold(1); + circuitBreakerConfig.setSlidingWindowMinimumNumberOfCalls(1); + circuitBreakerConfig.setSlidingWindowSize(1); + circuitBreakerConfig.setWaitDurationInOpenState(breakerWaitDuration); + + cluster = buildCluster(circuitBreakerConfig, ClientResources.builder()); + + final String key = "key"; + + // the first call should time out and open the breaker + assertThrows(RedisCommandTimeoutException.class, + () -> cluster.withCluster(connection -> connection.sync().blpop(2 * TIMEOUT.toMillis() / 1000d, key))); + + // the second call gets blocked by the breaker + final RedisException e = assertThrows(RedisException.class, + () -> cluster.withCluster(connection -> connection.sync().blpop(2 * TIMEOUT.toMillis() / 1000d, key))); + assertInstanceOf(CallNotPermittedException.class, e.getCause()); + + // wait for breaker to be half-open + Thread.sleep(breakerWaitDuration.toMillis() * 2); + + assertEquals(0, (Long) cluster.withCluster(connection -> connection.sync().llen(key))); + } + + @Test + void testShardUnavailable() { + final TestBreakerManager testBreakerManager = new TestBreakerManager(); + final CircuitBreakerConfiguration circuitBreakerConfig = new CircuitBreakerConfiguration(); + circuitBreakerConfig.setFailureRateThreshold(1); + circuitBreakerConfig.setSlidingWindowMinimumNumberOfCalls(2); + circuitBreakerConfig.setSlidingWindowSize(5); + + final ClientResources.Builder builder = CompositeNettyCustomizerClientResourcesBuilder.builder() + .nettyCustomizer(testBreakerManager); + + cluster = buildCluster(circuitBreakerConfig, builder); + + // this test will open the breaker on one shard and check that other shards are still available, + // so we get two 
nodes and a slot+key on each to test + final Pair<RedisClusterNode, RedisClusterNode> nodePair = + cluster.withCluster(connection -> { + Partitions partitions = ClusterPartitionParser.parse(connection.sync().clusterNodes()); + + assertTrue(partitions.size() >= 2); + + return new Pair<>(partitions.getPartition(0), partitions.getPartition(1)); + }); + + final RedisClusterNode unavailableNode = nodePair.first(); + final int unavailableSlot = unavailableNode.getSlots().getFirst(); + final String unavailableKey = "key::{%s}".formatted(RedisClusterUtil.getMinimalHashTag(unavailableSlot)); + + final int availableSlot = nodePair.second().getSlots().getFirst(); + final String availableKey = "key::{%s}".formatted(RedisClusterUtil.getMinimalHashTag(availableSlot)); + + cluster.useCluster(connection -> { + connection.sync().set(unavailableKey, "unavailable"); + connection.sync().set(availableKey, "available"); + + assertEquals("unavailable", connection.sync().get(unavailableKey)); + assertEquals("available", connection.sync().get(availableKey)); + }); + + // shard is now unavailable + testBreakerManager.openBreaker(unavailableNode.getUri()); + final RedisException e = assertThrows(RedisException.class, () -> + cluster.useCluster(connection -> connection.sync().get(unavailableKey))); + assertInstanceOf(CallNotPermittedException.class, e.getCause()); + + // other shard is still available + assertEquals("available", cluster.withCluster(connection -> connection.sync().get(availableKey))); + + // shard is available again + testBreakerManager.closeBreaker(unavailableNode.getUri()); + assertEquals("unavailable", cluster.withCluster(connection -> connection.sync().get(unavailableKey))); + } + + @Test + void testShardUnavailablePubSub() throws Exception { + final TestBreakerManager testBreakerManager = new TestBreakerManager(); + final CircuitBreakerConfiguration circuitBreakerConfig = new CircuitBreakerConfiguration(); + circuitBreakerConfig.setFailureRateThreshold(1); + 
circuitBreakerConfig.setSlidingWindowMinimumNumberOfCalls(2); + circuitBreakerConfig.setSlidingWindowSize(5); + + final ClientResources.Builder builder = CompositeNettyCustomizerClientResourcesBuilder.builder() + .nettyCustomizer(testBreakerManager); + + cluster = buildCluster(circuitBreakerConfig, builder); + + cluster.useCluster( + connection -> connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz")); + + // this test will open the breaker on one shard and check that other shards are still available, + // so we get two nodes and a slot+key on each to test + final Pair<RedisClusterNode, RedisClusterNode> nodePair = + cluster.withCluster(connection -> { + Partitions partitions = ClusterPartitionParser.parse(connection.sync().clusterNodes()); + + assertTrue(partitions.size() >= 2); + + return new Pair<>(partitions.getPartition(0), partitions.getPartition(1)); + }); + + final RedisClusterNode unavailableNode = nodePair.first(); + final int unavailableSlot = unavailableNode.getSlots().getFirst(); + final String unavailableKey = "key::{%s}".formatted(RedisClusterUtil.getMinimalHashTag(unavailableSlot)); + + final RedisClusterNode availableNode = nodePair.second(); + final int availableSlot = availableNode.getSlots().getFirst(); + final String availableKey = "key::{%s}".formatted(RedisClusterUtil.getMinimalHashTag(availableSlot)); + + final FaultTolerantPubSubConnection<String, String> pubSubConnection = cluster.createPubSubConnection(); + + // Keyspace notifications are delivered on a different thread, so we use a CountDownLatch to wait for the + // expected number of notifications to arrive + final AtomicReference<CountDownLatch> countDownLatchRef = new AtomicReference<>(); + + final Map<String, AtomicInteger> channelMessageCounts = new ConcurrentHashMap<>(); + final String keyspacePrefix = "__keyspace@0__:"; + final RedisClusterPubSubAdapter<String, String> listener = new RedisClusterPubSubAdapter<>() { + @Override + public void message(final RedisClusterNode node, final String channel, final String message) { + 
channelMessageCounts.computeIfAbsent(StringUtils.substringAfter(channel, keyspacePrefix), + k -> new AtomicInteger(0)) + .incrementAndGet(); + + countDownLatchRef.get().countDown(); + } + }; + + countDownLatchRef.set(new CountDownLatch(2)); + pubSubConnection.usePubSubConnection(c -> { + c.addListener(listener); + c.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(availableSlot)) + .commands() + .subscribe(keyspacePrefix + availableKey); + c.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(unavailableSlot)) + .commands() + .subscribe(keyspacePrefix + unavailableKey); + }); + + cluster.useCluster(connection -> { + connection.sync().set(availableKey, "ping1"); + connection.sync().set(unavailableKey, "ping1"); + }); + + countDownLatchRef.get().await(); + + assertEquals(1, channelMessageCounts.get(availableKey).get()); + assertEquals(1, channelMessageCounts.get(unavailableKey).get()); + + // shard is now unavailable + testBreakerManager.openBreaker(unavailableNode.getUri()); + + final RedisException e = assertThrows(RedisException.class, () -> + cluster.useCluster(connection -> connection.sync().set(unavailableKey, "ping2"))); + assertInstanceOf(CallNotPermittedException.class, e.getCause()); + assertEquals(1, channelMessageCounts.get(unavailableKey).get()); + assertEquals(1, channelMessageCounts.get(availableKey).get()); + + countDownLatchRef.set(new CountDownLatch(1)); + pubSubConnection.usePubSubConnection(connection -> connection.sync().set(availableKey, "ping2")); + + countDownLatchRef.get().await(); + + assertEquals(1, channelMessageCounts.get(unavailableKey).get()); + assertEquals(2, channelMessageCounts.get(availableKey).get()); + + // shard is available again + testBreakerManager.closeBreaker(unavailableNode.getUri()); + + countDownLatchRef.set(new CountDownLatch(2)); + + cluster.useCluster(connection -> { + connection.sync().set(availableKey, "ping3"); + connection.sync().set(unavailableKey, 
"ping3"); + }); + + countDownLatchRef.get().await(); + + assertEquals(2, channelMessageCounts.get(unavailableKey).get()); + assertEquals(3, channelMessageCounts.get(availableKey).get()); + } + + @ChannelHandler.Sharable + private static class TestBreakerManager extends ChannelDuplexHandler implements NettyCustomizer { + + private final Map<RedisURI, Set<LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler>> urisToChannelBreakers = new ConcurrentHashMap<>(); + private final AtomicInteger counter = new AtomicInteger(); + + @Override + public void afterChannelInitialized(Channel channel) { + channel.pipeline().addFirst("TestBreakerManager#" + counter.getAndIncrement(), this); + } + + @Override + public void connect(final ChannelHandlerContext ctx, final SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) throws Exception { + + super.connect(ctx, remoteAddress, localAddress, promise); + + final LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler channelCircuitBreakerHandler = + ctx.channel().pipeline().get(LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler.class); + + urisToChannelBreakers.computeIfAbsent(getRedisURI(ctx.channel()), ignored -> new HashSet<>()) + .add(channelCircuitBreakerHandler); + } + + private static RedisURI getRedisURI(Channel channel) { + final InetSocketAddress inetAddress = (InetSocketAddress) channel.remoteAddress(); + return RedisURI.create(inetAddress.getHostString(), inetAddress.getPort()); + } + + void openBreaker(final RedisURI redisURI) { + urisToChannelBreakers.get(redisURI).forEach(handler -> handler.breaker.transitionToOpenState()); + } + + void closeBreaker(final RedisURI redisURI) { + urisToChannelBreakers.get(redisURI).forEach(handler -> handler.breaker.transitionToClosedState()); + } + } + + static class CompositeNettyCustomizer implements NettyCustomizer { + + private final List<NettyCustomizer> nettyCustomizers = new ArrayList<>(); + + @Override + public void afterBootstrapInitialized(final Bootstrap bootstrap) { + nettyCustomizers.forEach(nc -> 
nc.afterBootstrapInitialized(bootstrap)); + } + + @Override + public void afterChannelInitialized(final Channel channel) { + nettyCustomizers.forEach(nc -> nc.afterChannelInitialized(channel)); + } + + void add(NettyCustomizer customizer) { + nettyCustomizers.add(customizer); + } + } + + static class CompositeNettyCustomizerClientResourcesBuilder implements ClientResources.Builder { + + private final CompositeNettyCustomizer compositeNettyCustomizer; + private final ClientResources.Builder delegate; + + static CompositeNettyCustomizerClientResourcesBuilder builder() { + return new CompositeNettyCustomizerClientResourcesBuilder(); + } + + private CompositeNettyCustomizerClientResourcesBuilder() { + this.compositeNettyCustomizer = new CompositeNettyCustomizer(); + this.delegate = ClientResources.builder().nettyCustomizer(compositeNettyCustomizer); + } + + + @Override + public ClientResources.Builder addressResolverGroup(final AddressResolverGroup<?> addressResolverGroup) { + delegate.addressResolverGroup(addressResolverGroup); + return this; + } + + @Override + public ClientResources.Builder commandLatencyRecorder(final CommandLatencyRecorder latencyRecorder) { + delegate.commandLatencyRecorder(latencyRecorder); + return this; + } + + @Override + @Deprecated + public ClientResources.Builder commandLatencyCollectorOptions( + final CommandLatencyCollectorOptions commandLatencyCollectorOptions) { + delegate.commandLatencyCollectorOptions(commandLatencyCollectorOptions); + return this; + } + + @Override + public ClientResources.Builder commandLatencyPublisherOptions( + final EventPublisherOptions commandLatencyPublisherOptions) { + delegate.commandLatencyPublisherOptions(commandLatencyPublisherOptions); + return this; + } + + @Override + public ClientResources.Builder computationThreadPoolSize(final int computationThreadPoolSize) { + delegate.computationThreadPoolSize(computationThreadPoolSize); + return this; + } + + @Override + @Deprecated + public ClientResources.Builder 
dnsResolver(final DnsResolver dnsResolver) { + delegate.dnsResolver(dnsResolver); + return this; + } + + @Override + public ClientResources.Builder eventBus(final EventBus eventBus) { + delegate.eventBus(eventBus); + return this; + } + + @Override + public ClientResources.Builder eventExecutorGroup(final EventExecutorGroup eventExecutorGroup) { + delegate.eventExecutorGroup(eventExecutorGroup); + return this; + } + + @Override + public ClientResources.Builder eventLoopGroupProvider(final EventLoopGroupProvider eventLoopGroupProvider) { + delegate.eventLoopGroupProvider(eventLoopGroupProvider); + return this; + } + + @Override + public ClientResources.Builder ioThreadPoolSize(final int ioThreadPoolSize) { + delegate.ioThreadPoolSize(ioThreadPoolSize); + return this; + } + + @Override + public ClientResources.Builder nettyCustomizer(final NettyCustomizer nettyCustomizer) { + compositeNettyCustomizer.add(nettyCustomizer); + return this; + } + + @Override + public ClientResources.Builder reconnectDelay(final Delay reconnectDelay) { + delegate.reconnectDelay(reconnectDelay); + return this; + } + + @Override + public ClientResources.Builder reconnectDelay(final Supplier<Delay> reconnectDelay) { + delegate.reconnectDelay(reconnectDelay); + return this; + } + + @Override + public ClientResources.Builder socketAddressResolver(final SocketAddressResolver socketAddressResolver) { + delegate.socketAddressResolver(socketAddressResolver); + return this; + } + + @Override + public ClientResources.Builder threadFactoryProvider(final ThreadFactoryProvider threadFactoryProvider) { + delegate.threadFactoryProvider(threadFactoryProvider); + return this; + } + + @Override + public ClientResources.Builder timer(final Timer timer) { + delegate.timer(timer); + return this; + } + + @Override + public ClientResources.Builder tracing(final Tracing tracing) { + delegate.tracing(tracing); + return this; + } + + @Override + public ClientResources build() { + return delegate.build(); + } + } + +}