Refactor provisioning plumbing to use Lettuce
This commit is contained in:
parent
ae70d1113c
commit
11829d1f9f
|
@ -91,9 +91,7 @@ clientPresenceCluster: # Redis server configuration for client presence cluster
|
||||||
configurationUri: redis://redis.example.com:6379/
|
configurationUri: redis://redis.example.com:6379/
|
||||||
|
|
||||||
pubsub: # Redis server configuration for pubsub cluster
|
pubsub: # Redis server configuration for pubsub cluster
|
||||||
url: redis://redis.example.com:6379/
|
uri: redis://redis.example.com:6379/
|
||||||
replicaUrls:
|
|
||||||
- redis://redis.example.com:6379/
|
|
||||||
|
|
||||||
pushSchedulerCluster: # Redis server configuration for push scheduler cluster
|
pushSchedulerCluster: # Redis server configuration for push scheduler cluster
|
||||||
configurationUri: redis://redis.example.com:6379/
|
configurationUri: redis://redis.example.com:6379/
|
||||||
|
|
|
@ -41,7 +41,6 @@ import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.EnumSet;
|
import java.util.EnumSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.ServiceLoader;
|
import java.util.ServiceLoader;
|
||||||
import java.util.concurrent.ArrayBlockingQueue;
|
import java.util.concurrent.ArrayBlockingQueue;
|
||||||
import java.util.concurrent.BlockingQueue;
|
import java.util.concurrent.BlockingQueue;
|
||||||
|
@ -65,7 +64,6 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation;
|
||||||
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
|
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.dispatch.DispatchManager;
|
|
||||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
|
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
|
||||||
|
@ -147,7 +145,6 @@ import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
|
||||||
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
|
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
|
||||||
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
|
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
|
||||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
||||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
|
||||||
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
|
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
|
||||||
import org.whispersystems.textsecuregcm.push.APNSender;
|
import org.whispersystems.textsecuregcm.push.APNSender;
|
||||||
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
|
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
|
||||||
|
@ -160,7 +157,6 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
|
||||||
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
import org.whispersystems.textsecuregcm.push.ReceiptSender;
|
||||||
import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger;
|
import org.whispersystems.textsecuregcm.redis.ConnectionEventLogger;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
|
||||||
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||||
import org.whispersystems.textsecuregcm.s3.PolicySigner;
|
import org.whispersystems.textsecuregcm.s3.PolicySigner;
|
||||||
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
|
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
|
||||||
|
@ -198,7 +194,6 @@ import org.whispersystems.textsecuregcm.storage.NonNormalizedAccountCrawlerListe
|
||||||
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
|
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
|
||||||
import org.whispersystems.textsecuregcm.storage.Profiles;
|
import org.whispersystems.textsecuregcm.storage.Profiles;
|
||||||
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
|
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
|
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
|
||||||
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
|
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
|
||||||
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
|
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
|
||||||
|
@ -387,11 +382,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbAsyncClient,
|
final VerificationSessions verificationSessions = new VerificationSessions(dynamoDbAsyncClient,
|
||||||
config.getDynamoDbTables().getVerificationSessions().getTableName(), clock);
|
config.getDynamoDbTables().getVerificationSessions().getTableName(), clock);
|
||||||
|
|
||||||
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache",
|
|
||||||
config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(),
|
|
||||||
config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
|
|
||||||
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
|
|
||||||
|
|
||||||
ClientResources redisClientResources = ClientResources.builder().build();
|
ClientResources redisClientResources = ClientResources.builder().build();
|
||||||
ConnectionEventLogger.logConnectionEvents(redisClientResources);
|
ConnectionEventLogger.logConnectionEvents(redisClientResources);
|
||||||
|
|
||||||
|
@ -530,14 +520,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
pendingAccountsManager, secureStorageClient, secureBackupClient, secureValueRecovery2Client, clientPresenceManager,
|
pendingAccountsManager, secureStorageClient, secureBackupClient, secureValueRecovery2Client, clientPresenceManager,
|
||||||
experimentEnrollmentManager, registrationRecoveryPasswordsManager, clock);
|
experimentEnrollmentManager, registrationRecoveryPasswordsManager, clock);
|
||||||
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
|
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
|
||||||
DispatchManager dispatchManager = new DispatchManager(pubSubClientFactory, Optional.empty());
|
|
||||||
PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager);
|
|
||||||
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
|
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
|
||||||
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
|
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
|
||||||
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
|
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
|
||||||
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
|
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
|
||||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
|
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
|
||||||
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
|
ProvisioningManager provisioningManager = new ProvisioningManager(config.getPubsubCacheConfiguration().getUri(), redisClientResources, config.getPubsubCacheConfiguration().getTimeout(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
|
||||||
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
||||||
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
||||||
config.getDynamoDbTables().getIssuedReceipts().getExpiration(),
|
config.getDynamoDbTables().getIssuedReceipts().getExpiration(),
|
||||||
|
@ -647,7 +635,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
|
|
||||||
environment.lifecycle().manage(apnSender);
|
environment.lifecycle().manage(apnSender);
|
||||||
environment.lifecycle().manage(apnPushNotificationScheduler);
|
environment.lifecycle().manage(apnPushNotificationScheduler);
|
||||||
environment.lifecycle().manage(pubSubManager);
|
environment.lifecycle().manage(provisioningManager);
|
||||||
environment.lifecycle().manage(accountDatabaseCrawler);
|
environment.lifecycle().manage(accountDatabaseCrawler);
|
||||||
environment.lifecycle().manage(directoryReconciliationAccountDatabaseCrawler);
|
environment.lifecycle().manage(directoryReconciliationAccountDatabaseCrawler);
|
||||||
environment.lifecycle().manage(accountCleanerAccountDatabaseCrawler);
|
environment.lifecycle().manage(accountCleanerAccountDatabaseCrawler);
|
||||||
|
@ -821,7 +809,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
|
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
|
||||||
webSocketEnvironment.getRequestLog(), 60000);
|
webSocketEnvironment.getRequestLog(), 60000);
|
||||||
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
|
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
|
||||||
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
|
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
|
||||||
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
|
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
|
||||||
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
|
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,7 @@ public class RedisConfiguration {
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
@NotEmpty
|
@NotEmpty
|
||||||
private String url;
|
private String uri;
|
||||||
|
|
||||||
@JsonProperty
|
|
||||||
@NotNull
|
|
||||||
private List<String> replicaUrls;
|
|
||||||
|
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
@NotNull
|
@NotNull
|
||||||
|
@ -31,12 +27,8 @@ public class RedisConfiguration {
|
||||||
@Valid
|
@Valid
|
||||||
private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration();
|
private CircuitBreakerConfiguration circuitBreaker = new CircuitBreakerConfiguration();
|
||||||
|
|
||||||
public String getUrl() {
|
public String getUri() {
|
||||||
return url;
|
return uri;
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getReplicaUrls() {
|
|
||||||
return replicaUrls;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Duration getTimeout() {
|
public Duration getTimeout() {
|
||||||
|
|
|
@ -5,37 +5,157 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.push;
|
package org.whispersystems.textsecuregcm.push;
|
||||||
|
|
||||||
import com.google.protobuf.ByteString;
|
|
||||||
import io.micrometer.core.instrument.Counter;
|
|
||||||
import io.micrometer.core.instrument.Metrics;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
|
||||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
|
||||||
public class ProvisioningManager {
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
private final PubSubManager pubSubManager;
|
import com.google.protobuf.ByteString;
|
||||||
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
|
||||||
|
import io.lettuce.core.RedisClient;
|
||||||
|
import io.lettuce.core.api.StatefulRedisConnection;
|
||||||
|
import io.lettuce.core.codec.ByteArrayCodec;
|
||||||
|
import io.lettuce.core.pubsub.RedisPubSubAdapter;
|
||||||
|
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
|
||||||
|
import io.lettuce.core.resource.ClientResources;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import io.micrometer.core.instrument.Tags;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.RedisOperation;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
|
||||||
|
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||||
|
|
||||||
private final Counter provisioningMessageOnlineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "true");
|
public class ProvisioningManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
|
||||||
private final Counter provisioningMessageOfflineCounter = Metrics.counter(name(getClass(), "sendProvisioningMessage"), "online", "false");
|
|
||||||
|
|
||||||
public ProvisioningManager(final PubSubManager pubSubManager) {
|
private final RedisClient redisClient;
|
||||||
this.pubSubManager = pubSubManager;
|
private final StatefulRedisPubSubConnection<byte[], byte[]> subscriptionConnection;
|
||||||
|
private final StatefulRedisConnection<byte[], byte[]> publicationConnection;
|
||||||
|
|
||||||
|
private final CircuitBreaker circuitBreaker;
|
||||||
|
|
||||||
|
private final Map<ProvisioningAddress, Consumer<PubSubProtos.PubSubMessage>> listenersByProvisioningAddress =
|
||||||
|
new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
private static final String ACTIVE_LISTENERS_GAUGE_NAME = name(ProvisioningManager.class, "activeListeners");
|
||||||
|
|
||||||
|
private static final String SEND_PROVISIONING_MESSAGE_COUNTER_NAME =
|
||||||
|
name(ProvisioningManager.class, "sendProvisioningMessage");
|
||||||
|
|
||||||
|
private static final String RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME =
|
||||||
|
name(ProvisioningManager.class, "receiveProvisioningMessage");
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(ProvisioningManager.class);
|
||||||
|
|
||||||
|
public ProvisioningManager(final String redisUri,
|
||||||
|
final ClientResources clientResources,
|
||||||
|
final Duration timeout,
|
||||||
|
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
|
||||||
|
|
||||||
|
this(RedisClient.create(clientResources, redisUri), timeout, circuitBreakerConfiguration);
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) {
|
@VisibleForTesting
|
||||||
PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
|
ProvisioningManager(final RedisClient redisClient,
|
||||||
|
final Duration timeout,
|
||||||
|
final CircuitBreakerConfiguration circuitBreakerConfiguration) {
|
||||||
|
|
||||||
|
this.redisClient = redisClient;
|
||||||
|
this.redisClient.setDefaultTimeout(timeout);
|
||||||
|
|
||||||
|
this.subscriptionConnection = redisClient.connectPubSub(new ByteArrayCodec());
|
||||||
|
this.publicationConnection = redisClient.connect(new ByteArrayCodec());
|
||||||
|
|
||||||
|
this.circuitBreaker = CircuitBreaker.of("pubsub-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig());
|
||||||
|
|
||||||
|
CircuitBreakerUtil.registerMetrics(circuitBreaker, ProvisioningManager.class);
|
||||||
|
|
||||||
|
Metrics.gaugeMapSize(ACTIVE_LISTENERS_GAUGE_NAME, Tags.empty(), listenersByProvisioningAddress);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() throws Exception {
|
||||||
|
subscriptionConnection.addListener(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() throws Exception {
|
||||||
|
subscriptionConnection.removeListener(this);
|
||||||
|
|
||||||
|
subscriptionConnection.close();
|
||||||
|
publicationConnection.close();
|
||||||
|
|
||||||
|
redisClient.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addListener(final ProvisioningAddress address, final Consumer<PubSubProtos.PubSubMessage> listener) {
|
||||||
|
listenersByProvisioningAddress.put(address, listener);
|
||||||
|
|
||||||
|
circuitBreaker.executeRunnable(
|
||||||
|
() -> subscriptionConnection.sync().subscribe(address.serialize().getBytes(StandardCharsets.UTF_8)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void removeListener(final ProvisioningAddress address) {
|
||||||
|
RedisOperation.unchecked(() -> circuitBreaker.executeRunnable(
|
||||||
|
() -> subscriptionConnection.sync().unsubscribe(address.serialize().getBytes(StandardCharsets.UTF_8))));
|
||||||
|
|
||||||
|
listenersByProvisioningAddress.remove(address);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean sendProvisioningMessage(final ProvisioningAddress address, final byte[] body) {
|
||||||
|
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder()
|
||||||
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
|
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
|
||||||
.setContent(ByteString.copyFrom(body))
|
.setContent(ByteString.copyFrom(body))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
if (pubSubManager.publish(address, pubSubMessage)) {
|
final boolean receiverPresent = circuitBreaker.executeSupplier(
|
||||||
provisioningMessageOnlineCounter.increment();
|
() -> publicationConnection.sync()
|
||||||
return true;
|
.publish(address.serialize().getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0);
|
||||||
} else {
|
|
||||||
provisioningMessageOfflineCounter.increment();
|
Metrics.counter(SEND_PROVISIONING_MESSAGE_COUNTER_NAME, "online", String.valueOf(receiverPresent)).increment();
|
||||||
return false;
|
|
||||||
|
return receiverPresent;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void message(final byte[] channel, final byte[] message) {
|
||||||
|
try {
|
||||||
|
final ProvisioningAddress address = new ProvisioningAddress(new String(channel, StandardCharsets.UTF_8));
|
||||||
|
final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.parseFrom(message);
|
||||||
|
|
||||||
|
if (pubSubMessage.getType() == PubSubProtos.PubSubMessage.Type.DELIVER) {
|
||||||
|
final Consumer<PubSubProtos.PubSubMessage> listener = listenersByProvisioningAddress.get(address);
|
||||||
|
|
||||||
|
boolean listenerPresent = false;
|
||||||
|
|
||||||
|
if (listener != null) {
|
||||||
|
listenerPresent = true;
|
||||||
|
listener.accept(pubSubMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
Metrics.counter(RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME, "listenerPresent", String.valueOf(listenerPresent)).increment();
|
||||||
|
}
|
||||||
|
} catch (final InvalidWebsocketAddressException e) {
|
||||||
|
logger.warn("Failed to parse provisioning address", e);
|
||||||
|
} catch (final InvalidProtocolBufferException e) {
|
||||||
|
logger.warn("Failed to parse pub/sub message", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void unsubscribed(final byte[] channel, final long count) {
|
||||||
|
try {
|
||||||
|
listenersByProvisioningAddress.remove(new ProvisioningAddress(new String(channel)));
|
||||||
|
} catch (final InvalidWebsocketAddressException e) {
|
||||||
|
logger.warn("Failed to parse provisioning address for `unsubscribe` event", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,30 +5,41 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.websocket;
|
package org.whispersystems.textsecuregcm.websocket;
|
||||||
|
|
||||||
import org.whispersystems.textsecuregcm.storage.PubSubManager;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||||
import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
public class ProvisioningConnectListener implements WebSocketConnectListener {
|
public class ProvisioningConnectListener implements WebSocketConnectListener {
|
||||||
|
|
||||||
private final PubSubManager pubSubManager;
|
private final ProvisioningManager provisioningManager;
|
||||||
|
|
||||||
public ProvisioningConnectListener(PubSubManager pubSubManager) {
|
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
|
||||||
this.pubSubManager = pubSubManager;
|
this.provisioningManager = provisioningManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onWebSocketConnect(WebSocketSessionContext context) {
|
public void onWebSocketConnect(WebSocketSessionContext context) {
|
||||||
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
|
|
||||||
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
|
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
|
||||||
|
context.addListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
|
||||||
|
|
||||||
pubSubManager.subscribe(provisioningAddress, connection);
|
provisioningManager.addListener(provisioningAddress, message -> {
|
||||||
|
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
|
||||||
|
|
||||||
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
|
final Optional<byte[]> body = Optional.of(message.getContent().toByteArray());
|
||||||
@Override
|
|
||||||
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
|
context.getClient().sendRequest("PUT", "/v1/message", List.of(HeaderUtils.getTimestampHeader()), body)
|
||||||
pubSubManager.unsubscribe(provisioningAddress, connection);
|
.whenComplete((ignored, throwable) -> context.getClient().close(1000, "Closed"));
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()),
|
||||||
|
Optional.of(MessageProtos.ProvisioningUuid.newBuilder()
|
||||||
|
.setUuid(provisioningAddress.getAddress())
|
||||||
|
.build()
|
||||||
|
.toByteArray()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2013-2020 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.websocket;
|
|
||||||
|
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Optional;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.whispersystems.dispatch.DispatchChannel;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
|
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
|
||||||
import org.whispersystems.websocket.WebSocketClient;
|
|
||||||
|
|
||||||
public class ProvisioningConnection implements DispatchChannel {
|
|
||||||
|
|
||||||
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
|
|
||||||
|
|
||||||
private final WebSocketClient client;
|
|
||||||
|
|
||||||
public ProvisioningConnection(WebSocketClient client) {
|
|
||||||
this.client = client;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onDispatchMessage(String channel, byte[] message) {
|
|
||||||
try {
|
|
||||||
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
|
|
||||||
|
|
||||||
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
|
|
||||||
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
|
|
||||||
|
|
||||||
client.sendRequest("PUT", "/v1/message", Collections.singletonList(HeaderUtils.getTimestampHeader()), body)
|
|
||||||
.thenAccept(response -> client.close(1001, "All you get."))
|
|
||||||
.exceptionally(throwable -> {
|
|
||||||
client.close(1001, "That's all!");
|
|
||||||
return null;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (InvalidProtocolBufferException e) {
|
|
||||||
logger.warn("Protobuf Error: ", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onDispatchSubscribed(String channel) {
|
|
||||||
try {
|
|
||||||
ProvisioningAddress address = new ProvisioningAddress(channel);
|
|
||||||
this.client.sendRequest("PUT", "/v1/address", Collections.singletonList(HeaderUtils.getTimestampHeader()),
|
|
||||||
Optional.of(ProvisioningUuid.newBuilder()
|
|
||||||
.setUuid(address.getAddress())
|
|
||||||
.build()
|
|
||||||
.toByteArray()));
|
|
||||||
|
|
||||||
} catch (InvalidWebsocketAddressException e) {
|
|
||||||
logger.warn("Badly formatted address", e);
|
|
||||||
this.client.close(1001, "Server Error");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onDispatchUnsubscribed(String channel) {
|
|
||||||
this.client.close(1001, "Closed");
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
package org.whispersystems.textsecuregcm.push;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.after;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.timeout;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
|
||||||
|
class ProvisioningManagerTest {
|
||||||
|
|
||||||
|
private ProvisioningManager provisioningManager;
|
||||||
|
|
||||||
|
@RegisterExtension
|
||||||
|
static final RedisSingletonExtension REDIS_EXTENSION = RedisSingletonExtension.builder().build();
|
||||||
|
|
||||||
|
private static final long PUBSUB_TIMEOUT_MILLIS = 1_000;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
|
provisioningManager = new ProvisioningManager(REDIS_EXTENSION.getRedisClient(), Duration.ofSeconds(1), new CircuitBreakerConfiguration());
|
||||||
|
provisioningManager.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
provisioningManager.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void sendProvisioningMessage() {
|
||||||
|
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
|
||||||
|
|
||||||
|
final byte[] content = new byte[16];
|
||||||
|
new Random().nextBytes(content);
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
|
||||||
|
|
||||||
|
provisioningManager.addListener(address, subscribedConsumer);
|
||||||
|
provisioningManager.sendProvisioningMessage(address, content);
|
||||||
|
|
||||||
|
final ArgumentCaptor<PubSubProtos.PubSubMessage> messageCaptor =
|
||||||
|
ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
|
||||||
|
|
||||||
|
verify(subscribedConsumer, timeout(PUBSUB_TIMEOUT_MILLIS)).accept(messageCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(PubSubProtos.PubSubMessage.Type.DELIVER, messageCaptor.getValue().getType());
|
||||||
|
assertEquals(ByteString.copyFrom(content), messageCaptor.getValue().getContent());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void removeListener() {
|
||||||
|
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
|
||||||
|
|
||||||
|
final byte[] content = new byte[16];
|
||||||
|
new Random().nextBytes(content);
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked") final Consumer<PubSubProtos.PubSubMessage> subscribedConsumer = mock(Consumer.class);
|
||||||
|
|
||||||
|
provisioningManager.addListener(address, subscribedConsumer);
|
||||||
|
provisioningManager.removeListener(address);
|
||||||
|
provisioningManager.sendProvisioningMessage(address, content);
|
||||||
|
|
||||||
|
// Make sure that we give the message enough time to show up (if it was going to) before declaring victory
|
||||||
|
verify(subscribedConsumer, after(PUBSUB_TIMEOUT_MILLIS).never()).accept(any());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2013-2022 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.redis;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assumptions.assumeFalse;
|
||||||
|
|
||||||
|
import io.lettuce.core.RedisClient;
|
||||||
|
import io.lettuce.core.api.StatefulRedisConnection;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.ServerSocket;
|
||||||
|
import org.junit.jupiter.api.extension.AfterAllCallback;
|
||||||
|
import org.junit.jupiter.api.extension.AfterEachCallback;
|
||||||
|
import org.junit.jupiter.api.extension.BeforeAllCallback;
|
||||||
|
import org.junit.jupiter.api.extension.BeforeEachCallback;
|
||||||
|
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||||
|
import redis.embedded.RedisServer;
|
||||||
|
|
||||||
|
public class RedisSingletonExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback {
|
||||||
|
|
||||||
|
private static RedisServer redisServer;
|
||||||
|
private RedisClient redisClient;
|
||||||
|
|
||||||
|
public static class RedisSingletonExtensionBuilder {
|
||||||
|
|
||||||
|
private RedisSingletonExtensionBuilder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public RedisSingletonExtension build() {
|
||||||
|
return new RedisSingletonExtension();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RedisSingletonExtensionBuilder builder() {
|
||||||
|
return new RedisSingletonExtensionBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void beforeAll(final ExtensionContext context) throws Exception {
|
||||||
|
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
|
||||||
|
|
||||||
|
redisServer = RedisServer.builder()
|
||||||
|
.setting("appendonly no")
|
||||||
|
.setting("save \"\"")
|
||||||
|
.setting("dir " + System.getProperty("java.io.tmpdir"))
|
||||||
|
.port(getAvailablePort())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
redisServer.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void beforeEach(final ExtensionContext context) {
|
||||||
|
redisClient = RedisClient.create(String.format("redis://127.0.0.1:%d", redisServer.ports().get(0)));
|
||||||
|
|
||||||
|
try (final StatefulRedisConnection<String, String> connection = redisClient.connect()) {
|
||||||
|
connection.sync().flushall();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void afterEach(final ExtensionContext context) {
|
||||||
|
redisClient.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void afterAll(final ExtensionContext context) {
|
||||||
|
if (redisServer != null) {
|
||||||
|
redisServer.stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public RedisClient getRedisClient() {
|
||||||
|
return redisClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int getAvailablePort() throws IOException {
|
||||||
|
try (ServerSocket socket = new ServerSocket(0)) {
|
||||||
|
socket.setReuseAddress(false);
|
||||||
|
return socket.getLocalPort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue