diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java
index 9f72400f6..fe64dd2de 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java
@@ -18,6 +18,7 @@ import io.lettuce.core.cluster.ClusterClientOptions;
import io.lettuce.core.cluster.ClusterTopologyRefreshOptions;
import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.resource.ClientResources;
@@ -73,9 +74,10 @@ public class FaultTolerantRedisCluster {
this.name = name;
+ final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker(name,
+ circuitBreakerConfig.toCircuitBreakerConfig(), Schedulers.newSingle("topology-changed-" + name, true));
this.clusterClient = RedisClusterClient.create(
- clientResourcesBuilder.nettyCustomizer(
- new LettuceShardCircuitBreaker(name, circuitBreakerConfig.toCircuitBreakerConfig())).
+ clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker).
build(),
redisUris);
this.clusterClient.setOptions(ClusterClientOptions.builder()
@@ -91,9 +93,15 @@ public class FaultTolerantRedisCluster {
.publishOnScheduler(true)
.build());
+ lettuceShardCircuitBreaker.setEventBus(clusterClient.getResources().eventBus());
+
this.stringConnection = clusterClient.connect();
this.binaryConnection = clusterClient.connect(ByteArrayCodec.INSTANCE);
+ // create a synthetic topology changed event to notify shard circuit breakers of initial upstreams
+ clusterClient.getResources().eventBus().publish(
+ new ClusterTopologyChangedEvent(Collections.emptyList(), clusterClient.getPartitions().getPartitions()));
+
this.retry = Retry.of(name + "-retry", retryConfiguration.toRetryConfigBuilder()
.retryOnException(exception -> exception instanceof RedisCommandTimeoutException).build());
final RetryConfig topologyChangedEventRetryConfig = RetryConfig.custom()
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java
index df9db0112..e0d007c78 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java
@@ -10,6 +10,9 @@ import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import io.lettuce.core.RedisNoScriptException;
+import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+import io.lettuce.core.event.EventBus;
import io.lettuce.core.protocol.CommandHandler;
import io.lettuce.core.protocol.CompleteableCommand;
import io.lettuce.core.protocol.RedisCommand;
@@ -21,28 +24,92 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.net.SocketAddress;
import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
+import reactor.core.scheduler.Scheduler;
+/**
+ * Adds a circuit breaker to every Netty {@link Channel} that gets created, so that a single unhealthy shard does not
+ * impact all cluster operations.
+ * <p>
+ * For metrics to be registered, users must create a synthetic {@link ClusterTopologyChangedEvent} after the
+ * initial connection. For example:
+ * <pre>{@code
+ * clusterClient.connect();
+ * clusterClient.getResources().eventBus().publish(
+ *     new ClusterTopologyChangedEvent(Collections.emptyList(), clusterClient.getPartitions().getPartitions()));
+ * }</pre>
+ */
public class LettuceShardCircuitBreaker implements NettyCustomizer {
+ private static final Logger logger = LoggerFactory.getLogger(LettuceShardCircuitBreaker.class);
+
private final String clusterName;
private final CircuitBreakerConfig circuitBreakerConfig;
+ private final Scheduler scheduler;
+  // "host:port" addresses of the shards currently known to be upstreams; shared with every per-channel handler
+  private final Set<String> upstreamAddresses = ConcurrentHashMap.newKeySet();
+ // The EventBus is not available at construction time, because it is one of the client
+ // resources, which cannot be built without this NettyCustomizer
+ private EventBus eventBus;
- public LettuceShardCircuitBreaker(final String clusterName, final CircuitBreakerConfig circuitBreakerConfig) {
+ public LettuceShardCircuitBreaker(final String clusterName, final CircuitBreakerConfig circuitBreakerConfig,
+ final Scheduler scheduler) {
this.clusterName = clusterName;
this.circuitBreakerConfig = circuitBreakerConfig;
+ this.scheduler = scheduler; // passed to subscribeOn() for the topology-changed event subscriptions
+ }
+
+  private static String toShardAddress(final RedisClusterNode redisClusterNode) {
+    return redisClusterNode.getUri().getHost() + ":" + redisClusterNode.getUri().getPort(); // "host:port"
+  }
+
+  void setEventBus(final EventBus eventBus) {
+    // The bus only exists once the client resources are built; from then on, track upstream changes
+    this.eventBus = eventBus;
+    eventBus.get()
+        .ofType(ClusterTopologyChangedEvent.class)
+        .subscribeOn(scheduler)
+        .subscribe(event -> {
+          final Set<String> currentUpstreams = event.after().stream()
+              .filter(node -> node.getRole().isUpstream())
+              .map(LettuceShardCircuitBreaker::toShardAddress)
+              .collect(Collectors.toSet());
+
+          final Set<String> previousUpstreams = event.before().stream()
+              .filter(node -> node.getRole().isUpstream())
+              .map(LettuceShardCircuitBreaker::toShardAddress)
+              .collect(Collectors.toSet());
+
+          // removeAll() reports whether the set changed, not whether anything remains, so check emptiness explicitly
+          previousUpstreams.removeAll(currentUpstreams);
+          if (!previousUpstreams.isEmpty()) {
+            logger.info("No longer upstream in cluster {}: {}", clusterName, StringUtils.join(previousUpstreams, ", "));
+          }
+          // Channels may be created at any time, not just immediately after the cluster client connect()s or when topology
+          // changes, so we maintain a set that can be queried by channel handlers during their connect() method.
+          upstreamAddresses.addAll(currentUpstreams);
+          upstreamAddresses.removeAll(previousUpstreams);
+        });
   }
@Override
public void afterChannelInitialized(final Channel channel) {
+ if (eventBus == null) {
+ throw new IllegalStateException("Event bus must be set before channel customization can occur");
+ }
+
final ChannelCircuitBreakerHandler channelCircuitBreakerHandler = new ChannelCircuitBreakerHandler(clusterName,
- circuitBreakerConfig);
+ circuitBreakerConfig, upstreamAddresses, eventBus, scheduler);
final String commandHandlerName = StreamSupport.stream(channel.pipeline().spliterator(), false)
.filter(entry -> entry.getValue() instanceof CommandHandler)
@@ -60,13 +127,48 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer {
private final String clusterName;
private final CircuitBreakerConfig circuitBreakerConfig;
+    private final AtomicBoolean registeredMetrics = new AtomicBoolean(false);
+    private final Set<String> upstreamAddresses;
+    // null until assigned from the channel's remote address; topology events seen before then are ignored
+    private String shardAddress;
@VisibleForTesting
CircuitBreaker breaker;
-    public ChannelCircuitBreakerHandler(final String name, CircuitBreakerConfig circuitBreakerConfig) {
+    public ChannelCircuitBreakerHandler(final String name, final CircuitBreakerConfig circuitBreakerConfig,
+        final Set<String> upstreamAddresses,
+        final EventBus eventBus, final Scheduler scheduler) {
       this.clusterName = name;
       this.circuitBreakerConfig = circuitBreakerConfig;
+      this.upstreamAddresses = upstreamAddresses;
+
+      // This channel's shard may only become an upstream later; register metrics when a topology change says so
+      eventBus.get()
+          .ofType(ClusterTopologyChangedEvent.class)
+          .subscribeOn(scheduler)
+          .subscribe(event -> {
+            if (shardAddress == null) {
+              logger.warn("Received a topology changed event without a shard address");
+              return;
+            }
+
+            final boolean nowUpstream = event.after().stream()
+                .filter(node -> node.getRole().isUpstream())
+                .map(LettuceShardCircuitBreaker::toShardAddress)
+                .anyMatch(shardAddress::equals);
+            if (nowUpstream) {
+              registerMetrics();
+            }
+          });
+    }
+
+    void registerMetrics() {
+      // Some counters are attached as event listeners, so a second registration would double-count; run only once
+      if (!registeredMetrics.compareAndSet(false, true)) {
+        return;
+      }
+      logger.info("Registered metrics for: {}/{}", clusterName, shardAddress);
+      CircuitBreakerUtil.registerMetrics(breaker, getClass(), Tags.of(CLUSTER_TAG_NAME, clusterName));
     }
@Override
@@ -79,9 +181,12 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer {
// match remote address, as it is inherited from the Bootstrap attributes and not updated for the Channel connection
// In some cases, like the default connection, the remote address includes the DNS hostname, which we want to exclude.
- final String shardAddress = StringUtils.substringAfter(remoteAddress.toString(), "/");
+ shardAddress = StringUtils.substringAfter(remoteAddress.toString(), "/");
breaker = CircuitBreaker.of("%s/%s-breaker".formatted(clusterName, shardAddress), circuitBreakerConfig);
- CircuitBreakerUtil.registerMetrics(breaker, getClass(), Tags.of(CLUSTER_TAG_NAME, clusterName));
+
+ if (upstreamAddresses.contains(shardAddress)) {
+ registerMetrics();
+ }
}
@Override
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java
index 4696aa398..24c6e9b82 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java
@@ -13,11 +13,13 @@ import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.codec.StringCodec;
+import io.lettuce.core.event.EventBus;
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.Command;
@@ -32,6 +34,7 @@ import io.netty.channel.embedded.EmbeddedChannel;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -42,23 +45,29 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
+import reactor.core.publisher.Flux;
+import reactor.core.scheduler.Schedulers;
class LettuceShardCircuitBreakerTest {
private LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler channelCircuitBreakerHandler;
+ private EventBus eventBus;
@BeforeEach
void setUp() {
+ eventBus = mock(EventBus.class);
+ when(eventBus.get()).thenReturn(Flux.never()); // an event stream that never emits, so the handler sees no topology events
channelCircuitBreakerHandler = new LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler(
- "test",
- new CircuitBreakerConfiguration().toCircuitBreakerConfig());
+ "test", new CircuitBreakerConfiguration().toCircuitBreakerConfig(), Collections.emptySet(), eventBus,
+ Schedulers.immediate());
}
@Test
void testAfterChannelInitialized() {
final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker("test",
- new CircuitBreakerConfiguration().toCircuitBreakerConfig());
+ new CircuitBreakerConfiguration().toCircuitBreakerConfig(), Schedulers.immediate());
+ lettuceShardCircuitBreaker.setEventBus(eventBus);
final Channel channel = new EmbeddedChannel(
new CommandHandler(ClientOptions.create(), ClientResources.create(), mock(Endpoint.class)));