diff --git a/service/pom.xml b/service/pom.xml index 838f153f9..67f9357f4 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -117,6 +117,12 @@ 5.3.3.RELEASE + + org.apache.commons + commons-pool2 + 2.8.1 + + org.postgresql postgresql diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManager.java index 1f6aae3dc..557e8bb8e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManager.java @@ -3,13 +3,15 @@ package org.whispersystems.textsecuregcm.metrics; import com.codahale.metrics.MetricRegistry; import com.google.common.annotations.VisibleForTesting; import io.lettuce.core.SetArgs; -import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.micrometer.core.instrument.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import java.time.Duration; +import java.util.Optional; import java.util.UUID; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; /** @@ -27,6 +29,8 @@ public class PushLatencyManager { private final FaultTolerantRedisCluster redisCluster; + private static final Logger log = LoggerFactory.getLogger(PushLatencyManager.class); + public PushLatencyManager(final FaultTolerantRedisCluster redisCluster) { this.redisCluster = redisCluster; } @@ -37,29 +41,33 @@ public class PushLatencyManager { @VisibleForTesting void recordPushSent(final UUID accountUuid, final long deviceId, final long currentTime) { - redisCluster.useCluster(connection -> - connection.async().set(getFirstUnacknowledgedPushKey(accountUuid, deviceId), String.valueOf(currentTime), SetArgs.Builder.nx().ex(TTL))); + try { + redisCluster.useCluster(connection -> + connection.sync().set(getFirstUnacknowledgedPushKey(accountUuid, deviceId), String.valueOf(currentTime), SetArgs.Builder.nx().ex(TTL))); + } catch (final Exception e) { + log.warn("Failed to record \"push notification sent\" timestamp", e); + } } public void recordQueueRead(final UUID accountUuid, final long deviceId, final String userAgent) { - getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).thenAccept(latency -> { - if (latency != null) { - Metrics.timer(TIMER_NAME, UserAgentTagUtil.getUserAgentTags(userAgent)).record(latency, TimeUnit.MILLISECONDS); - } - }); + final Optional maybeLatency = getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()); + + if (maybeLatency.isPresent()) { + Metrics.timer(TIMER_NAME, UserAgentTagUtil.getUserAgentTags(userAgent)).record(maybeLatency.get(), TimeUnit.MILLISECONDS); + } } @VisibleForTesting - CompletableFuture getLatencyAndClearTimestamp(final UUID accountUuid, final long deviceId, final long currentTimeMillis) { + Optional getLatencyAndClearTimestamp(final UUID accountUuid, final long deviceId, final long currentTimeMillis) { final String key = getFirstUnacknowledgedPushKey(accountUuid, deviceId); return redisCluster.withCluster(connection -> { - final RedisAdvancedClusterAsyncCommands commands = connection.async(); + final RedisAdvancedClusterCommands commands = connection.sync(); - final CompletableFuture getFuture = commands.get(key).toCompletableFuture(); + final String timestampString = commands.get(key); commands.del(key); - return getFuture.thenApply(timestampString -> timestampString != null ? currentTimeMillis - Long.parseLong(timestampString, 10) : null); + return timestampString != null ? Optional.of(currentTimeMillis - Long.parseLong(timestampString, 10)) : Optional.empty(); }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/RedisClusterHealthCheck.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/RedisClusterHealthCheck.java index 3be9ab90a..762fd8a51 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/RedisClusterHealthCheck.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/RedisClusterHealthCheck.java @@ -14,10 +14,8 @@ public class RedisClusterHealthCheck extends HealthCheck { } @Override - protected Result check() throws Exception { - return CompletableFuture.allOf(redisCluster.withCluster(connection -> connection.async().masters().commands().ping()).futures()) - .thenApply(v -> Result.healthy()) - .exceptionally(Result::unhealthy) - .get(); + protected Result check() { + redisCluster.withCluster(connection -> connection.sync().masters().commands().ping()); + return Result.healthy(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java index 4332c6f18..73413dca8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java @@ -196,7 +196,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter connection.async().masters().commands().unsubscribe(getKeyspaceNotificationChannel(presenceKey))); + pubSubConnection.usePubSubConnection(connection -> connection.sync().masters().commands().unsubscribe(getKeyspaceNotificationChannel(presenceKey))); } void pruneMissingPeers() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index aab61e4ce..80cfedd8c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -8,6 +8,11 @@ import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.support.ConnectionPoolSupport; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; import org.whispersystems.textsecuregcm.util.Constants; @@ -20,9 +25,8 @@ import java.util.function.Function; import java.util.stream.Collectors; /** - * A fault-tolerant access manager for a Redis cluster. A fault-tolerant Redis cluster has separate circuit breakers for - * read and write operations because the leader in a Redis cluster shard may fail while its read-only replicas can still - * serve traffic. + * A fault-tolerant access manager for a Redis cluster. A fault-tolerant Redis cluster provides managed, + * circuit-breaker-protected access to a pool of connections. */ public class FaultTolerantRedisCluster { @@ -30,14 +34,16 @@ public class FaultTolerantRedisCluster { private final RedisClusterClient clusterClient; - private final StatefulRedisClusterConnection stringClusterConnection; - private final StatefulRedisClusterConnection binaryClusterConnection; + private final GenericObjectPool> stringConnectionPool; + private final GenericObjectPool> binaryConnectionPool; private final List> pubSubConnections = new ArrayList<>(); private final CircuitBreakerConfiguration circuitBreakerConfiguration; private final CircuitBreaker circuitBreaker; + private static final Logger log = LoggerFactory.getLogger(FaultTolerantRedisCluster.class); + public FaultTolerantRedisCluster(final String name, final List urls, final Duration timeout, final CircuitBreakerConfiguration circuitBreakerConfiguration) { this(name, RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())), timeout, circuitBreakerConfiguration); } @@ -49,8 +55,11 @@ public class FaultTolerantRedisCluster { this.clusterClient = clusterClient; this.clusterClient.setDefaultTimeout(timeout); - this.stringClusterConnection = clusterClient.connect(); - this.binaryClusterConnection = clusterClient.connect(ByteArrayCodec.INSTANCE); + //noinspection unchecked,rawtypes,rawtypes + this.stringConnectionPool = ConnectionPoolSupport.createGenericObjectPool(clusterClient::connect, new GenericObjectPoolConfig()); + + //noinspection unchecked,rawtypes,rawtypes + this.binaryConnectionPool = ConnectionPoolSupport.createGenericObjectPool(() -> clusterClient.connect(ByteArrayCodec.INSTANCE), new GenericObjectPoolConfig()); this.circuitBreakerConfiguration = circuitBreakerConfiguration; this.circuitBreaker = CircuitBreaker.of(name + "-read", circuitBreakerConfiguration.toCircuitBreakerConfig()); @@ -61,8 +70,8 @@ public class FaultTolerantRedisCluster { } void shutdown() { - stringClusterConnection.close(); - binaryClusterConnection.close(); + stringConnectionPool.close(); + binaryConnectionPool.close(); for (final StatefulRedisClusterPubSubConnection pubSubConnection : pubSubConnections) { pubSubConnection.close(); @@ -72,19 +81,55 @@ public class FaultTolerantRedisCluster { } public void useCluster(final Consumer> consumer) { - this.circuitBreaker.executeRunnable(() -> consumer.accept(stringClusterConnection)); + acceptPooledConnection(stringConnectionPool, consumer); } - public T withCluster(final Function, T> consumer) { - return this.circuitBreaker.executeSupplier(() -> consumer.apply(stringClusterConnection)); + public T withCluster(final Function, T> function) { + return applyToPooledConnection(stringConnectionPool, function); } public void useBinaryCluster(final Consumer> consumer) { - this.circuitBreaker.executeRunnable(() -> consumer.accept(binaryClusterConnection)); + acceptPooledConnection(binaryConnectionPool, consumer); } - public T withBinaryCluster(final Function, T> consumer) { - return this.circuitBreaker.executeSupplier(() -> consumer.apply(binaryClusterConnection)); + public T withBinaryCluster(final Function, T> function) { + return applyToPooledConnection(binaryConnectionPool, function); + } + + private void acceptPooledConnection(final GenericObjectPool> pool, final Consumer> consumer) { + try { + circuitBreaker.executeCheckedRunnable(() -> { + try (final StatefulRedisClusterConnection connection = pool.borrowObject()) { + consumer.accept(connection); + } + }); + } catch (final Throwable t) { + log.warn("Redis operation failure", t); + + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t); + } + } + } + + private T applyToPooledConnection(final GenericObjectPool> pool, final Function, T> function) { + try { + return circuitBreaker.executeCheckedSupplier(() -> { + try (final StatefulRedisClusterConnection connection = pool.borrowObject()) { + return function.apply(connection); + } + }); + } catch (final Throwable t) { + log.warn("Redis operation failure", t); + + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t); + } + } } public FaultTolerantPubSubConnection createPubSubConnection() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java index 4a9214a05..7024a1676 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java @@ -1,37 +1,17 @@ package org.whispersystems.textsecuregcm.metrics; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; -import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; -import redis.clients.jedis.Jedis; -import redis.embedded.RedisServer; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; public class PushLatencyManagerTest extends AbstractRedisClusterTest { @Test - public void testGetLatency() throws ExecutionException, InterruptedException { + public void testGetLatency() { final PushLatencyManager pushLatencyManager = new PushLatencyManager(getRedisCluster()); final UUID accountUuid = UUID.randomUUID(); final long deviceId = 1; @@ -39,13 +19,13 @@ public class PushLatencyManagerTest extends AbstractRedisClusterTest { final long pushSentTimestamp = System.currentTimeMillis(); final long clearQueueTimestamp = pushSentTimestamp + expectedLatency; - assertNull(pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); + assertEquals(Optional.empty(), pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis())); { pushLatencyManager.recordPushSent(accountUuid, deviceId, pushSentTimestamp); - assertEquals(expectedLatency, (long)pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, clearQueueTimestamp).get()); - assertNull(pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); + assertEquals(Optional.of(expectedLatency), pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, clearQueueTimestamp)); + assertEquals(Optional.empty(), pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis())); } } }