Refactor provisioning plumbing to use Lettuce

This commit is contained in:
Jon Chambers 2022-08-16 12:57:33 -04:00 committed by Jon Chambers
parent ae70d1113c
commit 11829d1f9f
8 changed files with 342 additions and 135 deletions

View File

@ -91,9 +91,7 @@ clientPresenceCluster: # Redis server configuration for client presence cluster
configurationUri: redis://redis.example.com:6379/
pubsub: # Redis server configuration for pubsub cluster
url: redis://redis.example.com:6379/
replicaUrls:
- redis://redis.example.com:6379/
uri: redis://redis.example.com:6379/
pushSchedulerCluster: # Redis server configuration for push scheduler cluster
configurationUri: redis://redis.example.com:6379/

View File

@ -41,7 +41,6 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
@ -65,7 +64,6 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation;
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
@ -147,7 +145,6 @@ import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
@ -160,7 +157,6 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
@ -198,7 +194,6 @@ import org.whispersystems.textsecuregcm.storage.NonNormalizedAccountCrawlerListe
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
@ -387,11 +382,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbAsyncClient,
config.getDynamoDbTables().getVerificationSessions().getTableName(), clock);
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache",
config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(),
config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
ClientResources redisClientResources = ClientResources.builder().build();
ConnectionEventLogger.logConnectionEvents(redisClientResources);
@ -530,14 +520,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
pendingAccountsManager, secureStorageClient, secureBackupClient, secureValueRecovery2Client, clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, clock);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
DispatchManager dispatchManager = new DispatchManager(pubSubClientFactory, Optional.empty());
PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
ProvisioningManager provisioningManager = new ProvisioningManager(config.getPubsubCacheConfiguration().getUri(), redisClientResources, config.getPubsubCacheConfiguration().getTimeout(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
config.getDynamoDbTables().getIssuedReceipts().getExpiration(),
@ -647,7 +635,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(apnPushNotificationScheduler);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(provisioningManager);
environment.lifecycle().manage(accountDatabaseCrawler);
environment.lifecycle().manage(directoryReconciliationAccountDatabaseCrawler);
environment.lifecycle().manage(accountCleanerAccountDatabaseCrawler);
@ -821,7 +809,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), 60000);
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));

View File

@ -16,11 +16,7 @@ public class RedisConfiguration {
@JsonProperty
@NotEmpty
private String url;
@JsonProperty
@NotNull
private List<String> replicaUrls;
private String uri;
@JsonProperty
@NotNull
@ -31,12 +27,8 @@ public class RedisConfiguration {
@Valid
private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration();
public String getUrl() {
return url;
}
public List<String> getReplicaUrls() {
return replicaUrls;
public String getUri() {
return uri;
}
public Duration getTimeout() {

View File

@ -5,37 +5,157 @@
package org.whispersystems.textsecuregcm.push;
import com.google.protobuf.ByteString;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import static com.codahale.metrics.MetricRegistry.name;
public class ProvisioningManager {
private final PubSubManager pubSubManager;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.lettuce.core.RedisClient;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.pubsub.RedisPubSubAdapter;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
import io.lettuce.core.resource.ClientResources;
import io.micrometer.core.instrument.Metrics;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
private final Counter provisioningMessageOnlineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "true");
private final Counter provisioningMessageOfflineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "false");
public class ProvisioningManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
public ProvisioningManager(final PubSubManager pubSubManager) {
this.pubSubManager = pubSubManager;
}
private final RedisClient redisClient;
private final StatefulRedisPubSubConnection<byte[], byte[]> subscriptionConnection;
private final StatefulRedisConnection<byte[], byte[]> publicationConnection;
public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) {
PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(body))
.build();
private final CircuitBreaker circuitBreaker;
if (pubSubManager.publish(address, pubSubMessage)) {
provisioningMessageOnlineCounter.increment();
return true;
} else {
provisioningMessageOfflineCounter.increment();
return false;
private final Map<ProvisioningAddress, Consumer<PubSubProtos.PubSubMessage>> listenersByProvisioningAddress =
new ConcurrentHashMap<>();
private static final String ACTIVE_LISTENERS_GAUGE_NAME = name(ProvisioningManager.class, "activeListeners");
private static final String SEND_PROVISIONING_MESSAGE_COUNTER_NAME =
name(ProvisioningManager.class, "sendProvisioningMessage");
private static final String RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME =
name(ProvisioningManager.class, "receiveProvisioningMessage");
private static final Logger logger = LoggerFactory.getLogger(ProvisioningManager.class);
public ProvisioningManager(final String redisUri,
final ClientResources clientResources,
final Duration timeout,
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
this(RedisClient.create(clientResources, redisUri), timeout, circuitBreakerConfiguration);
}
@VisibleForTesting
ProvisioningManager(final RedisClient redisClient,
final Duration timeout,
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
this.redisClient = redisClient;
this.redisClient.setDefaultTimeout(timeout);
this.subscriptionConnection = redisClient.connectPubSub(new ByteArrayCodec());
this.publicationConnection = redisClient.connect(new ByteArrayCodec());
this.circuitBreaker = CircuitBreaker.of("pubsub-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig());
CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class);
Metrics.gaugeMapSize(ACTIVE_LISTENERS_GAUGE_NAME, Tags.empty(), listenersByProvisioningAddress);
}
@Override
public void start() throws Exception {
subscriptionConnection.addListener(this);
}
@Override
public void stop() throws Exception {
subscriptionConnection.removeListener(this);
subscriptionConnection.close();
publicationConnection.close();
redisClient.shutdown();
}
public void addListener(final ProvisioningAddress address, final Consumer<PubSubProtos.PubSubMessage> listener) {
listenersByProvisioningAddress.put(address, listener);
circuitBreaker.executeRunnable(
() -> subscriptionConnection.sync().subscribe(address.serialize().getBytes(StandardCharsets.UTF_8)));
}
public void removeListener(final ProvisioningAddress address) {
RedisOperation.unchecked(() -> circuitBreaker.executeRunnable(
() -> subscriptionConnection.sync().unsubscribe(address.serialize().getBytes(StandardCharsets.UTF_8))));
listenersByProvisioningAddress.remove(address);
}
public boolean sendProvisioningMessage(final ProvisioningAddress address, final byte[] body) {
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(body))
.build();
final boolean receiverPresent = circuitBreaker.executeSupplier(
() -> publicationConnection.sync()
.publish(address.serialize().getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0);
Metrics.counter(SEND_PROVISIONING_MESSAGE_COUNTER_NAME, "online", String.valueOf(receiverPresent)).increment();
return receiverPresent;
}
@Override
public void message(final byte[] channel, final byte[] message) {
try {
final ProvisioningAddress address = new ProvisioningAddress(new String(channel, StandardCharsets.UTF_8));
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.parseFrom(message);
if (pubSubMessage.getType() == PubSubProtos.PubSubMessage.Type.DELIVER) {
final Consumer<PubSubProtos.PubSubMessage> listener = listenersByProvisioningAddress.get(address);
boolean listenerPresent = false;
if (listener != null) {
listenerPresent = true;
listener.accept(pubSubMessage);
}
Metrics.counter(RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME, "listenerPresent", String.valueOf(listenerPresent)).increment();
}
} catch (final InvalidWebsocketAddressException e) {
logger.warn("Failed to parse provisioning address", e);
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse pub/sub message", e);
}
}
@Override
public void unsubscribed(final byte[] channel, final long count) {
try {
listenersByProvisioningAddress.remove(new ProvisioningAddress(new String(channel)));
} catch (final InvalidWebsocketAddressException e) {
logger.warn("Failed to parse provisioning address for `unsubscribe` event", e);
}
}
}

View File

@ -5,30 +5,41 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.util.List;
import java.util.Optional;
public class ProvisioningConnectListener implements WebSocketConnectListener {
private final PubSubManager pubSubManager;
private final ProvisioningManager provisioningManager;
public ProvisioningConnectListener(PubSubManager pubSubManager) {
this.pubSubManager = pubSubManager;
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
this.provisioningManager = provisioningManager;
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
context.addListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
pubSubManager.subscribe(provisioningAddress, connection);
provisioningManager.addListener(provisioningAddress, message -> {
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
pubSubManager.unsubscribe(provisioningAddress, connection);
}
final Optional<byte[]> body = Optional.of(message.getContent().toByteArray());
context.getClient().sendRequest("PUT", "/v1/message", List.of(HeaderUtils.getTimestampHeader()), body)
.whenComplete((ignored, throwable) -> context.getClient().close(1000, "Closed"));
});
context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()),
Optional.of(MessageProtos.ProvisioningUuid.newBuilder()
.setUuid(provisioningAddress.getAddress())
.build()
.toByteArray()));
}
}

View File

@ -1,69 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Collections;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.WebSocketClient;
public class ProvisioningConnection implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
private final WebSocketClient client;
public ProvisioningConnection(WebSocketClient client) {
this.client = client;
}
@Override
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
client.sendRequest("PUT", "/v1/message", Collections.singletonList(HeaderUtils.getTimestampHeader()), body)
.thenAccept(response -> client.close(1001, "All you get."))
.exceptionally(throwable -> {
client.close(1001, "That's all!");
return null;
});
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf Error: ", e);
}
}
@Override
public void onDispatchSubscribed(String channel) {
try {
ProvisioningAddress address = new ProvisioningAddress(channel);
this.client.sendRequest("PUT", "/v1/address", Collections.singletonList(HeaderUtils.getTimestampHeader()),
Optional.of(ProvisioningUuid.newBuilder()
.setUuid(address.getAddress())
.build()
.toByteArray()));
} catch (InvalidWebsocketAddressException e) {
logger.warn("Badly formatted address", e);
this.client.close(1001, "Server Error");
}
}
@Override
public void onDispatchUnsubscribed(String channel) {
this.client.close(1001, "Closed");
}
}

View File

@ -0,0 +1,82 @@
package org.whispersystems.textsecuregcm.push;
import com.google.protobuf.ByteString;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import java.time.Duration;
import java.util.Random;
import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
class ProvisioningManagerTest {
private ProvisioningManager provisioningManager;
@RegisterExtension
static final RedisSingletonExtension REDIS_EXTENSION = RedisSingletonExtension.builder().build();
private static final long PUBSUB_TIMEOUT_MILLIS = 1_000;
@BeforeEach
void setUp() throws Exception {
provisioningManager = new ProvisioningManager(REDIS_EXTENSION.getRedisClient(), Duration.ofSeconds(1), new CircuitBreakerConfiguration());
provisioningManager.start();
}
@AfterEach
void tearDown() throws Exception {
provisioningManager.stop();
}
@Test
void sendProvisioningMessage() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
provisioningManager.addListener(address, subscribedConsumer);
provisioningManager.sendProvisioningMessage(address, content);
final ArgumentCaptor<PubSubProtos.PubSubMessage> messageCaptor =
ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
verify(subscribedConsumer, timeout(PUBSUB_TIMEOUT_MILLIS)).accept(messageCaptor.capture());
assertEquals(PubSubProtos.PubSubMessage.Type.DELIVER, messageCaptor.getValue().getType());
assertEquals(ByteString.copyFrom(content), messageCaptor.getValue().getContent());
}
@Test
void removeListener() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
provisioningManager.addListener(address, subscribedConsumer);
provisioningManager.removeListener(address);
provisioningManager.sendProvisioningMessage(address, content);
// Make sure that we give the message enough time to show up (if it was going to) before declaring victory
verify(subscribedConsumer, after(PUBSUB_TIMEOUT_MILLIS).never()).accept(any());
}
}

View File

@ -0,0 +1,85 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.redis;
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import io.lettuce.core.RedisClient;
import io.lettuce.core.api.StatefulRedisConnection;
import java.io.IOException;
import java.net.ServerSocket;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import redis.embedded.RedisServer;
public class RedisSingletonExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback {
private static RedisServer redisServer;
private RedisClient redisClient;
public static class RedisSingletonExtensionBuilder {
private RedisSingletonExtensionBuilder() {
}
public RedisSingletonExtension build() {
return new RedisSingletonExtension();
}
}
public static RedisSingletonExtensionBuilder builder() {
return new RedisSingletonExtensionBuilder();
}
@Override
public void beforeAll(final ExtensionContext context) throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
redisServer = RedisServer.builder()
.setting("appendonly no")
.setting("save \"\"")
.setting("dir " + System.getProperty("java.io.tmpdir"))
.port(getAvailablePort())
.build();
redisServer.start();
}
@Override
public void beforeEach(final ExtensionContext context) {
redisClient = RedisClient.create(String.format("redis://127.0.0.1:%d", redisServer.ports().get(0)));
try (final StatefulRedisConnection<String, String> connection = redisClient.connect()) {
connection.sync().flushall();
}
}
@Override
public void afterEach(final ExtensionContext context) {
redisClient.shutdown();
}
@Override
public void afterAll(final ExtensionContext context) {
if (redisServer != null) {
redisServer.stop();
}
}
public RedisClient getRedisClient() {
return redisClient;
}
private static int getAvailablePort() throws IOException {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(false);
return socket.getLocalPort();
}
}
}