diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index 6d380cdc0..bf0696649 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -6,6 +6,7 @@ import io.github.resilience4j.circuitbreaker.CircuitBreaker; import io.lettuce.core.RedisURI; import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; import io.lettuce.core.codec.ByteArrayCodec; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; @@ -26,11 +27,13 @@ public class FaultTolerantRedisCluster { private final RedisClusterClient clusterClient; - private final StatefulRedisClusterConnection stringClusterConnection; - private final StatefulRedisClusterConnection binaryClusterConnection; + private final StatefulRedisClusterConnection stringClusterConnection; + private final StatefulRedisClusterConnection binaryClusterConnection; + private final StatefulRedisClusterPubSubConnection pubSubClusterConnection; private final CircuitBreaker readCircuitBreaker; private final CircuitBreaker writeCircuitBreaker; + private final CircuitBreaker pubSubCircuitBreaker; public FaultTolerantRedisCluster(final String name, final List urls, final Duration timeout, final CircuitBreakerConfiguration circuitBreakerConfiguration) { this(name, RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())), timeout, circuitBreakerConfiguration); @@ -43,8 +46,10 @@ public class FaultTolerantRedisCluster { this.stringClusterConnection = clusterClient.connect(); this.binaryClusterConnection = clusterClient.connect(ByteArrayCodec.INSTANCE); + this.pubSubClusterConnection = clusterClient.connectPubSub(); this.readCircuitBreaker = CircuitBreaker.of(name + "-read", circuitBreakerConfiguration.toCircuitBreakerConfig()); this.writeCircuitBreaker = CircuitBreaker.of(name + "-write", circuitBreakerConfiguration.toCircuitBreakerConfig()); + this.pubSubCircuitBreaker = CircuitBreaker.of(name + "-pubsub", circuitBreakerConfiguration.toCircuitBreakerConfig()); CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME), readCircuitBreaker, @@ -53,11 +58,16 @@ public class FaultTolerantRedisCluster { CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME), writeCircuitBreaker, FaultTolerantRedisCluster.class); + + CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME), + pubSubCircuitBreaker, + FaultTolerantRedisCluster.class); } void shutdown() { stringClusterConnection.close(); binaryClusterConnection.close(); + pubSubClusterConnection.close(); clusterClient.shutdown(); } @@ -93,4 +103,12 @@ public class FaultTolerantRedisCluster { public T withBinaryWriteCluster(final Function, T> consumer) { return this.writeCircuitBreaker.executeSupplier(() -> consumer.apply(binaryClusterConnection)); } + + public void usePubSubConnection(final Consumer> consumer) { + this.pubSubCircuitBreaker.executeRunnable(() -> consumer.accept(pubSubClusterConnection)); + } + + public T withPubSubConnection(final Function, T> consumer) { + return this.pubSubCircuitBreaker.executeSupplier(() -> consumer.apply(pubSubClusterConnection)); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java index f7ea9525f..b437c5437 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisClusterTest.java @@ -5,6 +5,8 @@ import io.lettuce.core.RedisException; import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; import org.junit.Before; import org.junit.Test; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; @@ -18,20 +20,25 @@ import static org.mockito.Mockito.when; public class FaultTolerantRedisClusterTest { - private RedisAdvancedClusterCommands clusterCommands; + private RedisAdvancedClusterCommands clusterCommands; + private RedisClusterPubSubCommands pubSubCommands; private FaultTolerantRedisCluster faultTolerantCluster; @SuppressWarnings("unchecked") @Before public void setUp() { - final RedisClusterClient clusterClient = mock(RedisClusterClient.class); - final StatefulRedisClusterConnection clusterConnection = mock(StatefulRedisClusterConnection.class); + final RedisClusterClient clusterClient = mock(RedisClusterClient.class); + final StatefulRedisClusterConnection clusterConnection = mock(StatefulRedisClusterConnection.class); + final StatefulRedisClusterPubSubConnection pubSubConnection = mock(StatefulRedisClusterPubSubConnection.class); clusterCommands = mock(RedisAdvancedClusterCommands.class); + pubSubCommands = mock(RedisClusterPubSubCommands.class); when(clusterClient.connect()).thenReturn(clusterConnection); + when(clusterClient.connectPubSub()).thenReturn(pubSubConnection); when(clusterConnection.sync()).thenReturn(clusterCommands); + when(pubSubConnection.sync()).thenReturn(pubSubCommands); final CircuitBreakerConfiguration breakerConfiguration = new CircuitBreakerConfiguration(); breakerConfiguration.setFailureRateThreshold(100); @@ -85,4 +92,19 @@ public class FaultTolerantRedisClusterTest { assertThrows(CircuitBreakerOpenException.class, () -> faultTolerantCluster.withWriteCluster(connection -> connection.sync().get("OH NO"))); } + + @Test + public void testPubSubBreaker() { + when(pubSubCommands.publish(anyString(), anyString())) + .thenReturn(1L) + .thenThrow(new RedisException("Badness has ensued.")); + + assertEquals(1L, (long)faultTolerantCluster.withPubSubConnection(connection -> connection.sync().publish("channel", "message"))); + + assertThrows(RedisException.class, + () -> faultTolerantCluster.withPubSubConnection(connection -> connection.sync().publish("channel", "OH NO"))); + + assertThrows(CircuitBreakerOpenException.class, + () -> faultTolerantCluster.withPubSubConnection(connection -> connection.sync().publish("channel", "OH NO"))); + } }