diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index 9f72400f6..fe64dd2de 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -18,6 +18,7 @@ import io.lettuce.core.cluster.ClusterClientOptions; import io.lettuce.core.cluster.ClusterTopologyRefreshOptions; import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; import io.lettuce.core.codec.ByteArrayCodec; import io.lettuce.core.resource.ClientResources; @@ -73,9 +74,10 @@ public class FaultTolerantRedisCluster { this.name = name; + final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker(name, + circuitBreakerConfig.toCircuitBreakerConfig(), Schedulers.newSingle("topology-changed-" + name, true)); this.clusterClient = RedisClusterClient.create( - clientResourcesBuilder.nettyCustomizer( - new LettuceShardCircuitBreaker(name, circuitBreakerConfig.toCircuitBreakerConfig())). + clientResourcesBuilder.nettyCustomizer(lettuceShardCircuitBreaker). build(), redisUris); this.clusterClient.setOptions(ClusterClientOptions.builder() @@ -91,9 +93,15 @@ public class FaultTolerantRedisCluster { .publishOnScheduler(true) .build()); + lettuceShardCircuitBreaker.setEventBus(clusterClient.getResources().eventBus()); + this.stringConnection = clusterClient.connect(); this.binaryConnection = clusterClient.connect(ByteArrayCodec.INSTANCE); + // create a synthetic topology changed event to notify shard circuit breakers of initial upstreams + clusterClient.getResources().eventBus().publish( + new ClusterTopologyChangedEvent(Collections.emptyList(), clusterClient.getPartitions().getPartitions())); + this.retry = Retry.of(name + "-retry", retryConfiguration.toRetryConfigBuilder() .retryOnException(exception -> exception instanceof RedisCommandTimeoutException).build()); final RetryConfig topologyChangedEventRetryConfig = RetryConfig.custom() diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java index df9db0112..e0d007c78 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java @@ -10,6 +10,9 @@ import io.github.resilience4j.circuitbreaker.CallNotPermittedException; import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; import io.lettuce.core.RedisNoScriptException; +import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; +import io.lettuce.core.event.EventBus; import io.lettuce.core.protocol.CommandHandler; import io.lettuce.core.protocol.CompleteableCommand; import io.lettuce.core.protocol.RedisCommand; @@ -21,28 +24,92 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import java.net.SocketAddress; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; import java.util.stream.StreamSupport; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; +import reactor.core.scheduler.Scheduler; +/** + * Adds a circuit breaker to every Netty {@link Channel} that gets created, so that a single unhealthy shard does not + * impact all cluster operations. + *

+ * For metrics to be registered, users must create a synthetic {@link ClusterTopologyChangedEvent} after the + * initial connection. For example: + *

+ *   clusterClient.connect();
+ *   clusterClient.getResources().eventBus().publish(
+ *         new ClusterTopologyChangedEvent(Collections.emptyList(), clusterClient.getPartitions().getPartitions()));
+ * 
+ */ public class LettuceShardCircuitBreaker implements NettyCustomizer { + private static final Logger logger = LoggerFactory.getLogger(LettuceShardCircuitBreaker.class); + private final String clusterName; private final CircuitBreakerConfig circuitBreakerConfig; + private final Scheduler scheduler; + // this set will be shared with all child channel breakers + private final Set upstreamAddresses = ConcurrentHashMap.newKeySet(); + // The EventBus is not available at construction time, because it is one of the client + // resources, which cannot be built without this NettyCustomizer + private EventBus eventBus; - public LettuceShardCircuitBreaker(final String clusterName, final CircuitBreakerConfig circuitBreakerConfig) { + public LettuceShardCircuitBreaker(final String clusterName, final CircuitBreakerConfig circuitBreakerConfig, + final Scheduler scheduler) { this.clusterName = clusterName; this.circuitBreakerConfig = circuitBreakerConfig; + this.scheduler = scheduler; + } + + private static String toShardAddress(final RedisClusterNode redisClusterNode) { + return "%s:%s".formatted(redisClusterNode.getUri().getHost(), redisClusterNode.getUri().getPort()); + } + + void setEventBus(final EventBus eventBus) { + this.eventBus = eventBus; + + eventBus.get() + .filter(e -> e instanceof ClusterTopologyChangedEvent) + .map(e -> (ClusterTopologyChangedEvent) e) + .subscribeOn(scheduler) + .subscribe(event -> { + + final Set currentUpstreams = event.after().stream() + .filter(node -> node.getRole().isUpstream()) + .map(LettuceShardCircuitBreaker::toShardAddress) + .collect(Collectors.toSet()); + + final Set previousUpstreams = event.before().stream() + .filter(node -> node.getRole().isUpstream()) + .map(LettuceShardCircuitBreaker::toShardAddress) + .collect(Collectors.toSet()); + if (previousUpstreams.removeAll(currentUpstreams)) { + logger.info("No longer upstream in cluster {}: {}", clusterName, StringUtils.join(previousUpstreams, ", ")); + } + + // Channels may be created at any time, not just immediately after the cluster client connect()s or when topology + // changes, so we maintain a set that can be queried by channel handlers during their connect() method. + upstreamAddresses.addAll(currentUpstreams); + upstreamAddresses.removeAll(previousUpstreams); + }); } @Override public void afterChannelInitialized(final Channel channel) { + if (eventBus == null) { + throw new IllegalStateException("Event bus must be set before channel customization can occur"); + } + final ChannelCircuitBreakerHandler channelCircuitBreakerHandler = new ChannelCircuitBreakerHandler(clusterName, - circuitBreakerConfig); + circuitBreakerConfig, upstreamAddresses, eventBus, scheduler); final String commandHandlerName = StreamSupport.stream(channel.pipeline().spliterator(), false) .filter(entry -> entry.getValue() instanceof CommandHandler) @@ -60,13 +127,48 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer { private final String clusterName; private final CircuitBreakerConfig circuitBreakerConfig; + private final AtomicBoolean registeredMetrics = new AtomicBoolean(false); + private final Set upstreamAddresses; + + private String shardAddress; @VisibleForTesting CircuitBreaker breaker; - public ChannelCircuitBreakerHandler(final String name, CircuitBreakerConfig circuitBreakerConfig) { + public ChannelCircuitBreakerHandler(final String name, final CircuitBreakerConfig circuitBreakerConfig, + final Set upstreamAddresses, + final EventBus eventBus, final Scheduler scheduler) { this.clusterName = name; this.circuitBreakerConfig = circuitBreakerConfig; + this.upstreamAddresses = upstreamAddresses; + + eventBus.get() + .filter(e -> e instanceof ClusterTopologyChangedEvent) + .map(e -> (ClusterTopologyChangedEvent) e) + .subscribeOn(scheduler) + .subscribe(event -> { + if (shardAddress == null) { + logger.warn("Received a topology changed event without a shard address"); + return; + } + + final Set newUpstreams = event.after().stream().filter(node -> node.getRole().isUpstream()) + .map(LettuceShardCircuitBreaker::toShardAddress) + .collect(Collectors.toSet()); + + if (newUpstreams.contains(shardAddress)) { + registerMetrics(); + } + }); + } + + void registerMetrics() { + // Registering metrics is not idempotent--some counters are added as event listeners, + // and there would be duplicated calls to increment() + if (registeredMetrics.compareAndSet(false, true)) { + logger.info("Registered metrics for: {}/{}", clusterName, shardAddress); + CircuitBreakerUtil.registerMetrics(breaker, getClass(), Tags.of(CLUSTER_TAG_NAME, clusterName)); + } } @Override @@ -79,9 +181,12 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer { // match remote address, as it is inherited from the Bootstrap attributes and not updated for the Channel connection // In some cases, like the default connection, the remote address includes the DNS hostname, which we want to exclude. - final String shardAddress = StringUtils.substringAfter(remoteAddress.toString(), "/"); + shardAddress = StringUtils.substringAfter(remoteAddress.toString(), "/"); breaker = CircuitBreaker.of("%s/%s-breaker".formatted(clusterName, shardAddress), circuitBreakerConfig); - CircuitBreakerUtil.registerMetrics(breaker, getClass(), Tags.of(CLUSTER_TAG_NAME, clusterName)); + + if (upstreamAddresses.contains(shardAddress)) { + registerMetrics(); + } } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java index 4696aa398..24c6e9b82 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java @@ -13,11 +13,13 @@ import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; import io.github.resilience4j.circuitbreaker.CallNotPermittedException; import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.lettuce.core.ClientOptions; import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.event.EventBus; import io.lettuce.core.output.StatusOutput; import io.lettuce.core.protocol.AsyncCommand; import io.lettuce.core.protocol.Command; @@ -32,6 +34,7 @@ import io.netty.channel.embedded.EmbeddedChannel; import java.io.IOException; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -42,23 +45,29 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; class LettuceShardCircuitBreakerTest { private LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler channelCircuitBreakerHandler; + private EventBus eventBus; @BeforeEach void setUp() { + eventBus = mock(EventBus.class); + when(eventBus.get()).thenReturn(Flux.never()); channelCircuitBreakerHandler = new LettuceShardCircuitBreaker.ChannelCircuitBreakerHandler( - "test", - new CircuitBreakerConfiguration().toCircuitBreakerConfig()); + "test", new CircuitBreakerConfiguration().toCircuitBreakerConfig(), Collections.emptySet(), eventBus, + Schedulers.immediate()); } @Test void testAfterChannelInitialized() { final LettuceShardCircuitBreaker lettuceShardCircuitBreaker = new LettuceShardCircuitBreaker("test", - new CircuitBreakerConfiguration().toCircuitBreakerConfig()); + new CircuitBreakerConfiguration().toCircuitBreakerConfig(), Schedulers.immediate()); + lettuceShardCircuitBreaker.setEventBus(eventBus); final Channel channel = new EmbeddedChannel( new CommandHandler(ClientOptions.create(), ClientResources.create(), mock(Endpoint.class)));