diff --git a/service/config/sample.yml b/service/config/sample.yml index a90f3409b..ddb3d08d0 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -91,9 +91,7 @@ clientPresenceCluster: # Redis server configuration for client presence cluster configurationUri: redis://redis.example.com:6379/ pubsub: # Redis server configuration for pubsub cluster - url: redis://redis.example.com:6379/ - replicaUrls: - - redis://redis.example.com:6379/ + uri: redis://redis.example.com:6379/ pushSchedulerCluster: # Redis server configuration for push scheduler cluster configurationUri: redis://redis.example.com:6379/ diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 3d096e2aa..af1aa2265 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -41,7 +41,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; import java.util.List; -import java.util.Optional; import java.util.ServiceLoader; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -65,7 +64,6 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; @@ -147,7 +145,6 @@ import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; import org.whispersystems.textsecuregcm.metrics.TrafficSource; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; -import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.push.APNSender; import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; @@ -160,7 +157,6 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; -import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; @@ -198,7 +194,6 @@ import org.whispersystems.textsecuregcm.storage.NonNormalizedAccountCrawlerListe import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.ProfilesManager; -import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb; import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; @@ -387,11 +382,6 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); - provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager)); + provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java index affb597eb..07824b19e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RedisConfiguration.java @@ -16,11 +16,7 @@ public class RedisConfiguration { @JsonProperty @NotEmpty - private String url; - - @JsonProperty - @NotNull - private List replicaUrls; + private String uri; @JsonProperty @NotNull @@ -31,12 +27,8 @@ public class RedisConfiguration { @Valid private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration(); - public String getUrl() { - return url; - } - - public List getReplicaUrls() { - return replicaUrls; + public String getUri() { + return uri; } public Duration getTimeout() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java index aa114c711..1c56cfa82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java @@ -5,37 +5,157 @@ package org.whispersystems.textsecuregcm.push; -import com.google.protobuf.ByteString; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.PubSubProtos; -import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; - import static com.codahale.metrics.MetricRegistry.name; -public class ProvisioningManager { - private final PubSubManager pubSubManager; +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.lifecycle.Managed; +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.lettuce.core.RedisClient; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.pubsub.RedisPubSubAdapter; +import io.lettuce.core.pubsub.StatefulRedisPubSubConnection; +import io.lettuce.core.resource.ClientResources; +import io.micrometer.core.instrument.Metrics; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import io.micrometer.core.instrument.Tags; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.redis.RedisOperation; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; +import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException; +import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; - private final Counter provisioningMessageOnlineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "true"); - private final Counter provisioningMessageOfflineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "false"); +public class ProvisioningManager extends RedisPubSubAdapter implements Managed { - public ProvisioningManager(final PubSubManager pubSubManager) { - this.pubSubManager = pubSubManager; - } + private final RedisClient redisClient; + private final StatefulRedisPubSubConnection subscriptionConnection; + private final StatefulRedisConnection publicationConnection; - public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) { - PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder() - .setType(PubSubProtos.PubSubMessage.Type.DELIVER) - .setContent(ByteString.copyFrom(body)) - .build(); + private final CircuitBreaker circuitBreaker; - if (pubSubManager.publish(address, pubSubMessage)) { - provisioningMessageOnlineCounter.increment(); - return true; - } else { - provisioningMessageOfflineCounter.increment(); - return false; + private final Map> listenersByProvisioningAddress = + new ConcurrentHashMap<>(); + + private static final String ACTIVE_LISTENERS_GAUGE_NAME = name(ProvisioningManager.class, "activeListeners"); + + private static final String SEND_PROVISIONING_MESSAGE_COUNTER_NAME = + name(ProvisioningManager.class, "sendProvisioningMessage"); + + private static final String RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME = + name(ProvisioningManager.class, "receiveProvisioningMessage"); + + private static final Logger logger = LoggerFactory.getLogger(ProvisioningManager.class); + + public ProvisioningManager(final String redisUri, + final ClientResources clientResources, + final Duration timeout, + final CircuitBreakerConfiguration circuitBreakerConfiguration) { + + this(RedisClient.create(clientResources, redisUri), timeout, circuitBreakerConfiguration); + } + + @VisibleForTesting + ProvisioningManager(final RedisClient redisClient, + final Duration timeout, + final CircuitBreakerConfiguration circuitBreakerConfiguration) { + + this.redisClient = redisClient; + this.redisClient.setDefaultTimeout(timeout); + + this.subscriptionConnection = redisClient.connectPubSub(new ByteArrayCodec()); + this.publicationConnection = redisClient.connect(new ByteArrayCodec()); + + this.circuitBreaker = CircuitBreaker.of("pubsub-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); + + CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class); + + Metrics.gaugeMapSize(ACTIVE_LISTENERS_GAUGE_NAME, Tags.empty(), listenersByProvisioningAddress); + } + + @Override + public void start() throws Exception { + subscriptionConnection.addListener(this); + } + + @Override + public void stop() throws Exception { + subscriptionConnection.removeListener(this); + + subscriptionConnection.close(); + publicationConnection.close(); + + redisClient.shutdown(); + } + + public void addListener(final ProvisioningAddress address, final Consumer listener) { + listenersByProvisioningAddress.put(address, listener); + + circuitBreaker.executeRunnable( + () -> subscriptionConnection.sync().subscribe(address.serialize().getBytes(StandardCharsets.UTF_8))); + } + + public void removeListener(final ProvisioningAddress address) { + RedisOperation.unchecked(() -> circuitBreaker.executeRunnable( + () -> subscriptionConnection.sync().unsubscribe(address.serialize().getBytes(StandardCharsets.UTF_8)))); + + listenersByProvisioningAddress.remove(address); + } + + public boolean sendProvisioningMessage(final ProvisioningAddress address, final byte[] body) { + final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.DELIVER) + .setContent(ByteString.copyFrom(body)) + .build(); + + final boolean receiverPresent = circuitBreaker.executeSupplier( + () -> publicationConnection.sync() + .publish(address.serialize().getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0); + + Metrics.counter(SEND_PROVISIONING_MESSAGE_COUNTER_NAME, "online", String.valueOf(receiverPresent)).increment(); + + return receiverPresent; + } + + @Override + public void message(final byte[] channel, final byte[] message) { + try { + final ProvisioningAddress address = new ProvisioningAddress(new String(channel, StandardCharsets.UTF_8)); + final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.parseFrom(message); + + if (pubSubMessage.getType() == PubSubProtos.PubSubMessage.Type.DELIVER) { + final Consumer listener = listenersByProvisioningAddress.get(address); + + boolean listenerPresent = false; + + if (listener != null) { + listenerPresent = true; + listener.accept(pubSubMessage); } + + Metrics.counter(RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME, "listenerPresent", String.valueOf(listenerPresent)).increment(); + } + } catch (final InvalidWebsocketAddressException e) { + logger.warn("Failed to parse provisioning address", e); + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse pub/sub message", e); } + } + + @Override + public void unsubscribed(final byte[] channel, final long count) { + try { + listenersByProvisioningAddress.remove(new ProvisioningAddress(new String(channel))); + } catch (final InvalidWebsocketAddressException e) { + logger.warn("Failed to parse provisioning address for `unsubscribe` event", e); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index e1249da43..4be1431e1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -5,30 +5,41 @@ package org.whispersystems.textsecuregcm.websocket; -import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import java.util.List; +import java.util.Optional; public class ProvisioningConnectListener implements WebSocketConnectListener { - private final PubSubManager pubSubManager; + private final ProvisioningManager provisioningManager; - public ProvisioningConnectListener(PubSubManager pubSubManager) { - this.pubSubManager = pubSubManager; + public ProvisioningConnectListener(final ProvisioningManager provisioningManager) { + this.provisioningManager = provisioningManager; } @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final ProvisioningConnection connection = new ProvisioningConnection(context.getClient()); - final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate(); + final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate(); + context.addListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress)); - pubSubManager.subscribe(provisioningAddress, connection); + provisioningManager.addListener(provisioningAddress, message -> { + assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER; - context.addListener(new WebSocketSessionContext.WebSocketEventListener() { - @Override - public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { - pubSubManager.unsubscribe(provisioningAddress, connection); - } + final Optional body = Optional.of(message.getContent().toByteArray()); + + context.getClient().sendRequest("PUT", "/v1/message", List.of(HeaderUtils.getTimestampHeader()), body) + .whenComplete((ignored, throwable) -> context.getClient().close(1000, "Closed")); }); + + context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()), + Optional.of(MessageProtos.ProvisioningUuid.newBuilder() + .setUuid(provisioningAddress.getAddress()) + .build() + .toByteArray())); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java deleted file mode 100644 index c8521a26b..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.websocket; - -import com.google.protobuf.InvalidProtocolBufferException; -import java.util.Collections; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.dispatch.DispatchChannel; -import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid; -import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.websocket.WebSocketClient; - -public class ProvisioningConnection implements DispatchChannel { - - private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class); - - private final WebSocketClient client; - - public ProvisioningConnection(WebSocketClient client) { - this.client = client; - } - - @Override - public void onDispatchMessage(String channel, byte[] message) { - try { - PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message); - - if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) { - Optional body = Optional.of(outgoingMessage.getContent().toByteArray()); - - client.sendRequest("PUT", "/v1/message", Collections.singletonList(HeaderUtils.getTimestampHeader()), body) - .thenAccept(response -> client.close(1001, "All you get.")) - .exceptionally(throwable -> { - client.close(1001, "That's all!"); - return null; - }); - } - } catch (InvalidProtocolBufferException e) { - logger.warn("Protobuf Error: ", e); - } - } - - @Override - public void onDispatchSubscribed(String channel) { - try { - ProvisioningAddress address = new ProvisioningAddress(channel); - this.client.sendRequest("PUT", "/v1/address", Collections.singletonList(HeaderUtils.getTimestampHeader()), - Optional.of(ProvisioningUuid.newBuilder() - .setUuid(address.getAddress()) - .build() - .toByteArray())); - - } catch (InvalidWebsocketAddressException e) { - logger.warn("Badly formatted address", e); - this.client.close(1001, "Server Error"); - } - } - - @Override - public void onDispatchUnsubscribed(String channel) { - this.client.close(1001, "Closed"); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java new file mode 100644 index 000000000..ceb4c9daf --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java @@ -0,0 +1,82 @@ +package org.whispersystems.textsecuregcm.push; + +import com.google.protobuf.ByteString; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; + +import java.time.Duration; +import java.util.Random; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.after; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +class ProvisioningManagerTest { + + private ProvisioningManager provisioningManager; + + @RegisterExtension + static final RedisSingletonExtension REDIS_EXTENSION = RedisSingletonExtension.builder().build(); + + private static final long PUBSUB_TIMEOUT_MILLIS = 1_000; + + @BeforeEach + void setUp() throws Exception { + provisioningManager = new ProvisioningManager(REDIS_EXTENSION.getRedisClient(), Duration.ofSeconds(1), new CircuitBreakerConfiguration()); + provisioningManager.start(); + } + + @AfterEach + void tearDown() throws Exception { + provisioningManager.stop(); + } + + @Test + void sendProvisioningMessage() { + final ProvisioningAddress address = new ProvisioningAddress("address", 0); + + final byte[] content = new byte[16]; + new Random().nextBytes(content); + + @SuppressWarnings("unchecked") final Consumer subscribedConsumer = mock(Consumer.class); + + provisioningManager.addListener(address, subscribedConsumer); + provisioningManager.sendProvisioningMessage(address, content); + + final ArgumentCaptor messageCaptor = + ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class); + + verify(subscribedConsumer, timeout(PUBSUB_TIMEOUT_MILLIS)).accept(messageCaptor.capture()); + + assertEquals(PubSubProtos.PubSubMessage.Type.DELIVER, messageCaptor.getValue().getType()); + assertEquals(ByteString.copyFrom(content), messageCaptor.getValue().getContent()); + } + + @Test + void removeListener() { + final ProvisioningAddress address = new ProvisioningAddress("address", 0); + + final byte[] content = new byte[16]; + new Random().nextBytes(content); + + @SuppressWarnings("unchecked") final Consumer subscribedConsumer = mock(Consumer.class); + + provisioningManager.addListener(address, subscribedConsumer); + provisioningManager.removeListener(address); + provisioningManager.sendProvisioningMessage(address, content); + + // Make sure that we give the message enough time to show up (if it was going to) before declaring victory + verify(subscribedConsumer, after(PUBSUB_TIMEOUT_MILLIS).never()).accept(any()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisSingletonExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisSingletonExtension.java new file mode 100644 index 000000000..cf1b28eb2 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisSingletonExtension.java @@ -0,0 +1,85 @@ +/* + * Copyright 2013-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.junit.jupiter.api.Assumptions.assumeFalse; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.api.StatefulRedisConnection; +import java.io.IOException; +import java.net.ServerSocket; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import redis.embedded.RedisServer; + +public class RedisSingletonExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback { + + private static RedisServer redisServer; + private RedisClient redisClient; + + public static class RedisSingletonExtensionBuilder { + + private RedisSingletonExtensionBuilder() { + } + + public RedisSingletonExtension build() { + return new RedisSingletonExtension(); + } + } + + public static RedisSingletonExtensionBuilder builder() { + return new RedisSingletonExtensionBuilder(); + } + + @Override + public void beforeAll(final ExtensionContext context) throws Exception { + assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows")); + + redisServer = RedisServer.builder() + .setting("appendonly no") + .setting("save \"\"") + .setting("dir " + System.getProperty("java.io.tmpdir")) + .port(getAvailablePort()) + .build(); + + redisServer.start(); + } + + @Override + public void beforeEach(final ExtensionContext context) { + redisClient = RedisClient.create(String.format("redis://127.0.0.1:%d", redisServer.ports().get(0))); + + try (final StatefulRedisConnection connection = redisClient.connect()) { + connection.sync().flushall(); + } + } + + @Override + public void afterEach(final ExtensionContext context) { + redisClient.shutdown(); + } + + @Override + public void afterAll(final ExtensionContext context) { + if (redisServer != null) { + redisServer.stop(); + } + } + + public RedisClient getRedisClient() { + return redisClient; + } + + private static int getAvailablePort() throws IOException { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(false); + return socket.getLocalPort(); + } + } +}