Refactor provisioning plumbing to use Lettuce
This commit is contained in:
parent
ae70d1113c
commit
11829d1f9f
|
@ -91,9 +91,7 @@ clientPresenceCluster: # Redis server configuration for client presence cluster
|
|||
configurationUri: redis://redis.example.com:6379/
|
||||
|
||||
pubsub: # Redis server configuration for pubsub cluster
|
||||
url: redis://redis.example.com:6379/
|
||||
replicaUrls:
|
||||
- redis://redis.example.com:6379/
|
||||
uri: redis://redis.example.com:6379/
|
||||
|
||||
pushSchedulerCluster: # Redis server configuration for push scheduler cluster
|
||||
configurationUri: redis://redis.example.com:6379/
|
||||
|
|
|
@ -41,7 +41,6 @@ import java.util.ArrayList;
|
|||
import java.util.Collections;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.ServiceLoader;
|
||||
import java.util.concurrent.ArrayBlockingQueue;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
|
@ -65,7 +64,6 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation;
|
|||
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.dispatch.DispatchManager;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
|
||||
|
@ -147,7 +145,6 @@ import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
|
|||
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
|
||||
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
|
||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
||||
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
|
||||
import org.whispersystems.textsecuregcm.push.APNSender;
|
||||
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
|
||||
|
@ -160,7 +157,6 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
|
|||
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
||||
import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||
import org.whispersystems.textsecuregcm.s3.PolicySigner;
|
||||
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
|
||||
|
@ -198,7 +194,6 @@ import org.whispersystems.textsecuregcm.storage.NonNormalizedAccountCrawlerListe
|
|||
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
|
||||
import org.whispersystems.textsecuregcm.storage.Profiles;
|
||||
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
|
||||
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
|
||||
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
|
||||
|
@ -387,11 +382,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbAsyncClient,
|
||||
config.getDynamoDbTables().getVerificationSessions().getTableName(), clock);
|
||||
|
||||
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache",
|
||||
config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(),
|
||||
config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
|
||||
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
|
||||
|
||||
ClientResources redisClientResources = ClientResources.builder().build();
|
||||
ConnectionEventLogger.logConnectionEvents(redisClientResources);
|
||||
|
||||
|
@ -530,14 +520,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
pendingAccountsManager, secureStorageClient, secureBackupClient, secureValueRecovery2Client, clientPresenceManager,
|
||||
experimentEnrollmentManager, registrationRecoveryPasswordsManager, clock);
|
||||
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
|
||||
DispatchManager dispatchManager = new DispatchManager(pubSubClientFactory, Optional.empty());
|
||||
PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager);
|
||||
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
|
||||
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
|
||||
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
|
||||
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
|
||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
|
||||
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
|
||||
ProvisioningManager provisioningManager = new ProvisioningManager(config.getPubsubCacheConfiguration().getUri(), redisClientResources, config.getPubsubCacheConfiguration().getTimeout(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
|
||||
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
||||
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
||||
config.getDynamoDbTables().getIssuedReceipts().getExpiration(),
|
||||
|
@ -647,7 +635,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
|
||||
environment.lifecycle().manage(apnSender);
|
||||
environment.lifecycle().manage(apnPushNotificationScheduler);
|
||||
environment.lifecycle().manage(pubSubManager);
|
||||
environment.lifecycle().manage(provisioningManager);
|
||||
environment.lifecycle().manage(accountDatabaseCrawler);
|
||||
environment.lifecycle().manage(directoryReconciliationAccountDatabaseCrawler);
|
||||
environment.lifecycle().manage(accountCleanerAccountDatabaseCrawler);
|
||||
|
@ -821,7 +809,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
|
||||
webSocketEnvironment.getRequestLog(), 60000);
|
||||
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
|
||||
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
|
||||
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
|
||||
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
|
||||
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
|
||||
|
||||
|
|
|
@ -16,11 +16,7 @@ public class RedisConfiguration {
|
|||
|
||||
@JsonProperty
|
||||
@NotEmpty
|
||||
private String url;
|
||||
|
||||
@JsonProperty
|
||||
@NotNull
|
||||
private List<String> replicaUrls;
|
||||
private String uri;
|
||||
|
||||
@JsonProperty
|
||||
@NotNull
|
||||
|
@ -31,12 +27,8 @@ public class RedisConfiguration {
|
|||
@Valid
|
||||
private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration();
|
||||
|
||||
public String getUrl() {
|
||||
return url;
|
||||
}
|
||||
|
||||
public List<String> getReplicaUrls() {
|
||||
return replicaUrls;
|
||||
public String getUri() {
|
||||
return uri;
|
||||
}
|
||||
|
||||
public Duration getTimeout() {
|
||||
|
|
|
@ -5,37 +5,157 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.push;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
|
||||
public class ProvisioningManager {
|
||||
private final PubSubManager pubSubManager;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
|
||||
import io.lettuce.core.RedisClient;
|
||||
import io.lettuce.core.api.StatefulRedisConnection;
|
||||
import io.lettuce.core.codec.ByteArrayCodec;
|
||||
import io.lettuce.core.pubsub.RedisPubSubAdapter;
|
||||
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
|
||||
import io.lettuce.core.resource.ClientResources;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import io.micrometer.core.instrument.Tags;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisOperation;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
|
||||
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
|
||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||
|
||||
private final Counter provisioningMessageOnlineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "true");
|
||||
private final Counter provisioningMessageOfflineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "false");
|
||||
public class ProvisioningManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
|
||||
|
||||
public ProvisioningManager(final PubSubManager pubSubManager) {
|
||||
this.pubSubManager = pubSubManager;
|
||||
}
|
||||
private final RedisClient redisClient;
|
||||
private final StatefulRedisPubSubConnection<byte[], byte[]> subscriptionConnection;
|
||||
private final StatefulRedisConnection<byte[], byte[]> publicationConnection;
|
||||
|
||||
public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) {
|
||||
PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
|
||||
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
|
||||
.setContent(ByteString.copyFrom(body))
|
||||
.build();
|
||||
private final CircuitBreaker circuitBreaker;
|
||||
|
||||
if (pubSubManager.publish(address, pubSubMessage)) {
|
||||
provisioningMessageOnlineCounter.increment();
|
||||
return true;
|
||||
} else {
|
||||
provisioningMessageOfflineCounter.increment();
|
||||
return false;
|
||||
private final Map<ProvisioningAddress, Consumer<PubSubProtos.PubSubMessage>> listenersByProvisioningAddress =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
private static final String ACTIVE_LISTENERS_GAUGE_NAME = name(ProvisioningManager.class, "activeListeners");
|
||||
|
||||
private static final String SEND_PROVISIONING_MESSAGE_COUNTER_NAME =
|
||||
name(ProvisioningManager.class, "sendProvisioningMessage");
|
||||
|
||||
private static final String RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME =
|
||||
name(ProvisioningManager.class, "receiveProvisioningMessage");
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ProvisioningManager.class);
|
||||
|
||||
public ProvisioningManager(final String redisUri,
|
||||
final ClientResources clientResources,
|
||||
final Duration timeout,
|
||||
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
|
||||
|
||||
this(RedisClient.create(clientResources, redisUri), timeout, circuitBreakerConfiguration);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
ProvisioningManager(final RedisClient redisClient,
|
||||
final Duration timeout,
|
||||
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
|
||||
|
||||
this.redisClient = redisClient;
|
||||
this.redisClient.setDefaultTimeout(timeout);
|
||||
|
||||
this.subscriptionConnection = redisClient.connectPubSub(new ByteArrayCodec());
|
||||
this.publicationConnection = redisClient.connect(new ByteArrayCodec());
|
||||
|
||||
this.circuitBreaker = CircuitBreaker.of("pubsub-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig());
|
||||
|
||||
CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class);
|
||||
|
||||
Metrics.gaugeMapSize(ACTIVE_LISTENERS_GAUGE_NAME, Tags.empty(), listenersByProvisioningAddress);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
subscriptionConnection.addListener(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws Exception {
|
||||
subscriptionConnection.removeListener(this);
|
||||
|
||||
subscriptionConnection.close();
|
||||
publicationConnection.close();
|
||||
|
||||
redisClient.shutdown();
|
||||
}
|
||||
|
||||
public void addListener(final ProvisioningAddress address, final Consumer<PubSubProtos.PubSubMessage> listener) {
|
||||
listenersByProvisioningAddress.put(address, listener);
|
||||
|
||||
circuitBreaker.executeRunnable(
|
||||
() -> subscriptionConnection.sync().subscribe(address.serialize().getBytes(StandardCharsets.UTF_8)));
|
||||
}
|
||||
|
||||
public void removeListener(final ProvisioningAddress address) {
|
||||
RedisOperation.unchecked(() -> circuitBreaker.executeRunnable(
|
||||
() -> subscriptionConnection.sync().unsubscribe(address.serialize().getBytes(StandardCharsets.UTF_8))));
|
||||
|
||||
listenersByProvisioningAddress.remove(address);
|
||||
}
|
||||
|
||||
public boolean sendProvisioningMessage(final ProvisioningAddress address, final byte[] body) {
|
||||
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
|
||||
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
|
||||
.setContent(ByteString.copyFrom(body))
|
||||
.build();
|
||||
|
||||
final boolean receiverPresent = circuitBreaker.executeSupplier(
|
||||
() -> publicationConnection.sync()
|
||||
.publish(address.serialize().getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0);
|
||||
|
||||
Metrics.counter(SEND_PROVISIONING_MESSAGE_COUNTER_NAME, "online", String.valueOf(receiverPresent)).increment();
|
||||
|
||||
return receiverPresent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void message(final byte[] channel, final byte[] message) {
|
||||
try {
|
||||
final ProvisioningAddress address = new ProvisioningAddress(new String(channel, StandardCharsets.UTF_8));
|
||||
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.parseFrom(message);
|
||||
|
||||
if (pubSubMessage.getType() == PubSubProtos.PubSubMessage.Type.DELIVER) {
|
||||
final Consumer<PubSubProtos.PubSubMessage> listener = listenersByProvisioningAddress.get(address);
|
||||
|
||||
boolean listenerPresent = false;
|
||||
|
||||
if (listener != null) {
|
||||
listenerPresent = true;
|
||||
listener.accept(pubSubMessage);
|
||||
}
|
||||
|
||||
Metrics.counter(RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME, "listenerPresent", String.valueOf(listenerPresent)).increment();
|
||||
}
|
||||
} catch (final InvalidWebsocketAddressException e) {
|
||||
logger.warn("Failed to parse provisioning address", e);
|
||||
} catch (final InvalidProtocolBufferException e) {
|
||||
logger.warn("Failed to parse pub/sub message", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unsubscribed(final byte[] channel, final long count) {
|
||||
try {
|
||||
listenersByProvisioningAddress.remove(new ProvisioningAddress(new String(channel)));
|
||||
} catch (final InvalidWebsocketAddressException e) {
|
||||
logger.warn("Failed to parse provisioning address for `unsubscribe` event", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,30 +5,41 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||
import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public class ProvisioningConnectListener implements WebSocketConnectListener {
|
||||
|
||||
private final PubSubManager pubSubManager;
|
||||
private final ProvisioningManager provisioningManager;
|
||||
|
||||
public ProvisioningConnectListener(PubSubManager pubSubManager) {
|
||||
this.pubSubManager = pubSubManager;
|
||||
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
|
||||
this.provisioningManager = provisioningManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketConnect(WebSocketSessionContext context) {
|
||||
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
|
||||
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
|
||||
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
|
||||
context.addListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
|
||||
|
||||
pubSubManager.subscribe(provisioningAddress, connection);
|
||||
provisioningManager.addListener(provisioningAddress, message -> {
|
||||
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
|
||||
|
||||
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
|
||||
@Override
|
||||
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
|
||||
pubSubManager.unsubscribe(provisioningAddress, connection);
|
||||
}
|
||||
final Optional<byte[]> body = Optional.of(message.getContent().toByteArray());
|
||||
|
||||
context.getClient().sendRequest("PUT", "/v1/message", List.of(HeaderUtils.getTimestampHeader()), body)
|
||||
.whenComplete((ignored, throwable) -> context.getClient().close(1000, "Closed"));
|
||||
});
|
||||
|
||||
context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()),
|
||||
Optional.of(MessageProtos.ProvisioningUuid.newBuilder()
|
||||
.setUuid(provisioningAddress.getAddress())
|
||||
.build()
|
||||
.toByteArray()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.util.Collections;
|
||||
import java.util.Optional;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.dispatch.DispatchChannel;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.websocket.WebSocketClient;
|
||||
|
||||
public class ProvisioningConnection implements DispatchChannel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
|
||||
|
||||
private final WebSocketClient client;
|
||||
|
||||
public ProvisioningConnection(WebSocketClient client) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDispatchMessage(String channel, byte[] message) {
|
||||
try {
|
||||
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
|
||||
|
||||
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
|
||||
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
|
||||
|
||||
client.sendRequest("PUT", "/v1/message", Collections.singletonList(HeaderUtils.getTimestampHeader()), body)
|
||||
.thenAccept(response -> client.close(1001, "All you get."))
|
||||
.exceptionally(throwable -> {
|
||||
client.close(1001, "That's all!");
|
||||
return null;
|
||||
});
|
||||
}
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
logger.warn("Protobuf Error: ", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDispatchSubscribed(String channel) {
|
||||
try {
|
||||
ProvisioningAddress address = new ProvisioningAddress(channel);
|
||||
this.client.sendRequest("PUT", "/v1/address", Collections.singletonList(HeaderUtils.getTimestampHeader()),
|
||||
Optional.of(ProvisioningUuid.newBuilder()
|
||||
.setUuid(address.getAddress())
|
||||
.build()
|
||||
.toByteArray()));
|
||||
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Badly formatted address", e);
|
||||
this.client.close(1001, "Server Error");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDispatchUnsubscribed(String channel) {
|
||||
this.client.close(1001, "Closed");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package org.whispersystems.textsecuregcm.push;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension;
|
||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Random;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.after;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.timeout;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class ProvisioningManagerTest {
|
||||
|
||||
private ProvisioningManager provisioningManager;
|
||||
|
||||
@RegisterExtension
|
||||
static final RedisSingletonExtension REDIS_EXTENSION = RedisSingletonExtension.builder().build();
|
||||
|
||||
private static final long PUBSUB_TIMEOUT_MILLIS = 1_000;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
provisioningManager = new ProvisioningManager(REDIS_EXTENSION.getRedisClient(), Duration.ofSeconds(1), new CircuitBreakerConfiguration());
|
||||
provisioningManager.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
provisioningManager.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendProvisioningMessage() {
|
||||
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
|
||||
|
||||
final byte[] content = new byte[16];
|
||||
new Random().nextBytes(content);
|
||||
|
||||
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
|
||||
|
||||
provisioningManager.addListener(address, subscribedConsumer);
|
||||
provisioningManager.sendProvisioningMessage(address, content);
|
||||
|
||||
final ArgumentCaptor<PubSubProtos.PubSubMessage> messageCaptor =
|
||||
ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
|
||||
|
||||
verify(subscribedConsumer, timeout(PUBSUB_TIMEOUT_MILLIS)).accept(messageCaptor.capture());
|
||||
|
||||
assertEquals(PubSubProtos.PubSubMessage.Type.DELIVER, messageCaptor.getValue().getType());
|
||||
assertEquals(ByteString.copyFrom(content), messageCaptor.getValue().getContent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void removeListener() {
|
||||
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
|
||||
|
||||
final byte[] content = new byte[16];
|
||||
new Random().nextBytes(content);
|
||||
|
||||
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
|
||||
|
||||
provisioningManager.addListener(address, subscribedConsumer);
|
||||
provisioningManager.removeListener(address);
|
||||
provisioningManager.sendProvisioningMessage(address, content);
|
||||
|
||||
// Make sure that we give the message enough time to show up (if it was going to) before declaring victory
|
||||
verify(subscribedConsumer, after(PUBSUB_TIMEOUT_MILLIS).never()).accept(any());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Copyright 2013-2022 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.redis;
|
||||
|
||||
import static org.junit.jupiter.api.Assumptions.assumeFalse;
|
||||
|
||||
import io.lettuce.core.RedisClient;
|
||||
import io.lettuce.core.api.StatefulRedisConnection;
|
||||
import java.io.IOException;
|
||||
import java.net.ServerSocket;
|
||||
import org.junit.jupiter.api.extension.AfterAllCallback;
|
||||
import org.junit.jupiter.api.extension.AfterEachCallback;
|
||||
import org.junit.jupiter.api.extension.BeforeAllCallback;
|
||||
import org.junit.jupiter.api.extension.BeforeEachCallback;
|
||||
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||
import redis.embedded.RedisServer;
|
||||
|
||||
public class RedisSingletonExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback {
|
||||
|
||||
private static RedisServer redisServer;
|
||||
private RedisClient redisClient;
|
||||
|
||||
public static class RedisSingletonExtensionBuilder {
|
||||
|
||||
private RedisSingletonExtensionBuilder() {
|
||||
}
|
||||
|
||||
public RedisSingletonExtension build() {
|
||||
return new RedisSingletonExtension();
|
||||
}
|
||||
}
|
||||
|
||||
public static RedisSingletonExtensionBuilder builder() {
|
||||
return new RedisSingletonExtensionBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeAll(final ExtensionContext context) throws Exception {
|
||||
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
|
||||
|
||||
redisServer = RedisServer.builder()
|
||||
.setting("appendonly no")
|
||||
.setting("save \"\"")
|
||||
.setting("dir " + System.getProperty("java.io.tmpdir"))
|
||||
.port(getAvailablePort())
|
||||
.build();
|
||||
|
||||
redisServer.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeEach(final ExtensionContext context) {
|
||||
redisClient = RedisClient.create(String.format("redis://127.0.0.1:%d", redisServer.ports().get(0)));
|
||||
|
||||
try (final StatefulRedisConnection<String, String> connection = redisClient.connect()) {
|
||||
connection.sync().flushall();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterEach(final ExtensionContext context) {
|
||||
redisClient.shutdown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterAll(final ExtensionContext context) {
|
||||
if (redisServer != null) {
|
||||
redisServer.stop();
|
||||
}
|
||||
}
|
||||
|
||||
public RedisClient getRedisClient() {
|
||||
return redisClient;
|
||||
}
|
||||
|
||||
private static int getAvailablePort() throws IOException {
|
||||
try (ServerSocket socket = new ServerSocket(0)) {
|
||||
socket.setReuseAddress(false);
|
||||
return socket.getLocalPort();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue