diff --git a/pom.xml b/pom.xml index 94fef0be6..cc7aa26fc 100644 --- a/pom.xml +++ b/pom.xml @@ -57,7 +57,7 @@ 2.9.0 1.7.10 1.4.0 - 6.1.9.RELEASE + 6.2.0.RELEASE 8.12.54 7.0.1 1.9.3 @@ -151,6 +151,13 @@ pom import + + io.projectreactor + reactor-bom + 2020.0.23 + pom + import + com.eatthepath pushy diff --git a/service/pom.xml b/service/pom.xml index 305e1f53b..e03e640d1 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -228,6 +228,10 @@ io.github.resilience4j resilience4j-retry + + io.github.resilience4j + resilience4j-reactor + io.grpc @@ -407,7 +411,6 @@ io.projectreactor reactor-core - 3.3.22.RELEASE io.vavr @@ -420,6 +423,11 @@ test + + io.projectreactor + reactor-test + + org.signal embedded-redis diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 5967a0ea0..30a109750 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -116,7 +116,6 @@ import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters; import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; -import org.whispersystems.textsecuregcm.limits.RateLimitChallengeOptionManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; @@ -330,6 +329,13 @@ public class WhisperServerService extends Application messageDeletionQueue = new ArrayBlockingQueue<>(10_000); + Metrics.gaugeCollectionSize(name(getClass(), "messageDeletionQueueSize"), Collections.emptyList(), + messageDeletionQueue); + ExecutorService messageDeletionAsyncExecutor = environment.lifecycle() + .executorService(name(getClass(), "messageDeletionAsyncExecutor-%d")).maxThreads(16) + .workQueue(messageDeletionQueue).build(); + Accounts accounts = new Accounts(dynamicConfigurationManager, dynamoDbClient, dynamoDbAsyncClient, @@ -345,9 +351,10 @@ public class WhisperServerService extends Application { + public CompletableFuture removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) { + return messagesManager.delete( + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + uuid, + null) + .thenAccept(maybeDeletedMessage -> { + maybeDeletedMessage.ifPresent(deletedMessage -> { - WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), auth.getAuthenticatedDevice()); + WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), + auth.getAuthenticatedDevice()); - if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { - try { - receiptSender.sendReceipt( - UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), - UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); - } catch (Exception e) { - logger.warn("Failed to send delivery receipt", e); - } - } - }); + if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { + try { + receiptSender.sendReceipt( + UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), + UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); + } catch (Exception e) { + logger.warn("Failed to send delivery receipt", e); + } + } + }); + }); } @Timed diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java index dc187cfed..3b3e70898 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/ClusterLuaScript.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.redis; import com.google.common.annotations.VisibleForTesting; +import io.lettuce.core.RedisException; import io.lettuce.core.RedisNoScriptException; import io.lettuce.core.ScriptOutputType; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; @@ -15,9 +16,12 @@ import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.apache.commons.codec.binary.Hex; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; public class ClusterLuaScript { @@ -73,11 +77,31 @@ public class ClusterLuaScript { execute(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY))); } + public CompletableFuture executeAsync(final List keys, final List args) { + return redisCluster.withCluster(connection -> + executeAsync(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY))); + } + + public Flux executeReactive(final List keys, final List args) { + return redisCluster.withCluster(connection -> + executeReactive(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY))); + } + public Object executeBinary(final List keys, final List args) { return redisCluster.withBinaryCluster(connection -> execute(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY))); } + public CompletableFuture executeBinaryAsync(final List keys, final List args) { + return redisCluster.withBinaryCluster(connection -> + executeAsync(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY))); + } + + public Flux executeBinaryReactive(final List keys, final List args) { + return redisCluster.withBinaryCluster(connection -> + executeReactive(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY))); + } + private Object execute(final StatefulRedisClusterConnection connection, final T[] keys, final T[] args) { try { try { @@ -90,4 +114,32 @@ public class ClusterLuaScript { throw e; } } + + private CompletableFuture executeAsync(final StatefulRedisClusterConnection connection, + final T[] keys, final T[] args) { + + return connection.async().evalsha(sha, scriptOutputType, keys, args) + .exceptionallyCompose(throwable -> { + if (throwable instanceof RedisNoScriptException) { + return connection.async().eval(script, scriptOutputType, keys, args); + } + + log.warn("Failed to execute script", throwable); + throw new RedisException(throwable); + }).toCompletableFuture(); + } + + private Flux executeReactive(final StatefulRedisClusterConnection connection, + final T[] keys, final T[] args) { + + return connection.reactive().evalsha(sha, scriptOutputType, keys, args) + .onErrorResume(e -> { + if (e instanceof RedisNoScriptException) { + return connection.reactive().eval(script, scriptOutputType, keys, args); + } + + log.warn("Failed to execute script", e); + return Mono.error(e); + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java index f36f71478..46d9abf50 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantRedisCluster.java @@ -8,6 +8,8 @@ package org.whispersystems.textsecuregcm.redis; import com.codahale.metrics.SharedMetricRegistries; import com.google.common.annotations.VisibleForTesting; import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator; +import io.github.resilience4j.reactor.retry.RetryOperator; import io.github.resilience4j.retry.Retry; import io.lettuce.core.ClientOptions.DisconnectedBehavior; import io.lettuce.core.RedisCommandTimeoutException; @@ -24,11 +26,13 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import org.reactivestreams.Publisher; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; import org.whispersystems.textsecuregcm.util.Constants; +import reactor.core.publisher.Flux; /** * A fault-tolerant access manager for a Redis cluster. A fault-tolerant Redis cluster provides managed, @@ -81,64 +85,79 @@ public class FaultTolerantRedisCluster { } void shutdown() { - stringConnection.close(); - binaryConnection.close(); + stringConnection.close(); + binaryConnection.close(); - for (final StatefulRedisClusterPubSubConnection pubSubConnection : pubSubConnections) { - pubSubConnection.close(); - } + for (final StatefulRedisClusterPubSubConnection pubSubConnection : pubSubConnections) { + pubSubConnection.close(); + } - clusterClient.shutdown(); + clusterClient.shutdown(); } - public String getName() { - return name; - } + public String getName() { + return name; + } - public void useCluster(final Consumer> consumer) { - useConnection(stringConnection, consumer); - } + public void useCluster(final Consumer> consumer) { + useConnection(stringConnection, consumer); + } - public T withCluster(final Function, T> function) { - return withConnection(stringConnection, function); - } + public T withCluster(final Function, T> function) { + return withConnection(stringConnection, function); + } - public void useBinaryCluster(final Consumer> consumer) { - useConnection(binaryConnection, consumer); - } + public void useBinaryCluster(final Consumer> consumer) { + useConnection(binaryConnection, consumer); + } - public T withBinaryCluster(final Function, T> function) { - return withConnection(binaryConnection, function); - } + public T withBinaryCluster(final Function, T> function) { + return withConnection(binaryConnection, function); + } - private void useConnection(final StatefulRedisClusterConnection connection, final Consumer> consumer) { - try { - circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } - } + public Publisher withBinaryClusterReactive( + final Function, Publisher> function) { + return withConnectionReactive(binaryConnection, function); + } - private T withConnection(final StatefulRedisClusterConnection connection, final Function, T> function) { - try { - return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection))); - } catch (final Throwable t) { - if (t instanceof RedisException) { - throw (RedisException) t; - } else { - throw new RedisException(t); - } - } + private void useConnection(final StatefulRedisClusterConnection connection, + final Consumer> consumer) { + try { + circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw (RedisException) t; + } else { + throw new RedisException(t); + } } + } - public FaultTolerantPubSubConnection createPubSubConnection() { - final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(); - pubSubConnections.add(pubSubConnection); - - return new FaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry); + private T withConnection(final StatefulRedisClusterConnection connection, + final Function, T> function) { + try { + return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection))); + } catch (final Throwable t) { + if (t instanceof RedisException) { + throw (RedisException) t; + } else { + throw new RedisException(t); + } } + } + + private Publisher withConnectionReactive(final StatefulRedisClusterConnection connection, + final Function, Publisher> function) { + + return Flux.from(function.apply(connection)) + .transformDeferred(RetryOperator.of(retry)) + .transformDeferred(CircuitBreakerOperator.of(circuitBreaker)); + } + + public FaultTolerantPubSubConnection createPubSubConnection() { + final StatefulRedisClusterPubSubConnection pubSubConnection = clusterClient.connectPubSub(); + pubSubConnections.add(pubSubConnection); + + return new FaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java index 4d48b2f5f..8325a90df 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java @@ -26,7 +26,7 @@ import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.WriteRequest; -public class AbstractDynamoDbStore { +public abstract class AbstractDynamoDbStore { private final DynamoDbClient dynamoDbClient; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 195813e22..fde41c5d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -22,6 +22,7 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -34,23 +35,32 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.function.Predicate; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.RedisClusterUtil; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; public class MessagesCache extends RedisClusterPubSubAdapter implements Managed { private final FaultTolerantRedisCluster readDeleteCluster; private final FaultTolerantPubSubConnection pubSubConnection; + private final Clock clock; private final ExecutorService notificationExecutorService; + private final ExecutorService messageDeletionExecutorService; private final ClusterLuaScript insertScript; private final ClusterLuaScript removeByGuidScript; @@ -79,22 +89,23 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; - private static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); + @VisibleForTesting + static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); - private static final String REMOVE_TIMER_NAME = name(MessagesCache.class, "remove"); - - private static final String REMOVE_METHOD_TAG = "method"; - private static final String REMOVE_METHOD_UUID = "uuid"; + private static final int PAGE_SIZE = 100; private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class); public MessagesCache(final FaultTolerantRedisCluster insertCluster, final FaultTolerantRedisCluster readDeleteCluster, - final ExecutorService notificationExecutorService) throws IOException { + final Clock clock, final ExecutorService notificationExecutorService, + final ExecutorService messageDeletionExecutorService) throws IOException { this.readDeleteCluster = readDeleteCluster; this.pubSubConnection = readDeleteCluster.createPubSubConnection(); + this.clock = clock; this.notificationExecutorService = notificationExecutorService; + this.messageDeletionExecutorService = messageDeletionExecutorService; this.insertScript = ClusterLuaScript.fromResource(insertCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.removeByGuidScript = ClusterLuaScript.fromResource(readDeleteCluster, "lua/remove_item_by_guid.lua", @@ -147,33 +158,39 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp guid.toString().getBytes(StandardCharsets.UTF_8)))); } - public Optional remove(final UUID destinationUuid, final long destinationDevice, + public CompletableFuture> remove(final UUID destinationUuid, + final long destinationDevice, final UUID messageGuid) { - return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst(); + + return remove(destinationUuid, destinationDevice, List.of(messageGuid)) + .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0))); } @SuppressWarnings("unchecked") - public List remove(final UUID destinationUuid, final long destinationDevice, + public CompletableFuture> remove(final UUID destinationUuid, + final long destinationDevice, final List messageGuids) { - final List serialized = (List) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, - REMOVE_METHOD_UUID).record(() -> - removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + + return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice), getMessageQueueMetadataKey(destinationUuid, destinationDevice), getQueueIndexKey(destinationUuid, destinationDevice)), messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)) - .collect(Collectors.toList()))); + .collect(Collectors.toList())) + .thenApplyAsync(result -> { + List serialized = (List) result; - final List removedMessages = new ArrayList<>(serialized.size()); + final List removedMessages = new ArrayList<>(serialized.size()); - for (final byte[] bytes : serialized) { - try { - removedMessages.add(MessageProtos.Envelope.parseFrom(bytes)); - } catch (final InvalidProtocolBufferException e) { - logger.warn("Failed to parse envelope", e); - } - } + for (final byte[] bytes : serialized) { + try { + removedMessages.add(MessageProtos.Envelope.parseFrom(bytes)); + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } - return removedMessages; + return removedMessages; + }, messageDeletionExecutorService); } public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) { @@ -181,50 +198,110 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0); } - @SuppressWarnings("unchecked") - public List get(final UUID destinationUuid, final long destinationDevice, final int limit) { - return getMessagesTimer.record(() -> { - final List queueItems = (List) getItemsScript.executeBinary( - List.of(getMessageQueueKey(destinationUuid, destinationDevice), - getPersistInProgressKey(destinationUuid, destinationDevice)), - List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8))); + public Publisher get(final UUID destinationUuid, final long destinationDevice) { - final long earliestAllowableEphemeralTimestamp = - System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); + final long earliestAllowableEphemeralTimestamp = + clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); - final List messageEntities; - final List staleEphemeralMessageGuids = new ArrayList<>(); + final Flux allMessages = getAllMessages(destinationUuid, destinationDevice) + .publish() + // We expect exactly two subscribers to this base flux: + // 1. the websocket that delivers messages to clients + // 2. an internal process to discard stale ephemeral messages + // The discard subscriber will subscribe immediately, but we don’t want to do any work if the + // websocket never subscribes. + .autoConnect(2); - if (queueItems.size() % 2 == 0) { - messageEntities = new ArrayList<>(queueItems.size() / 2); + final Flux messagesToPublish = allMessages + .filter(Predicate.not(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp))); - for (int i = 0; i < queueItems.size() - 1; i += 2) { - try { - final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i)); - if (message.getEphemeral() && message.getTimestamp() < earliestAllowableEphemeralTimestamp) { - staleEphemeralMessageGuids.add(UUID.fromString(message.getServerGuid())); - continue; - } + final Flux staleEphemeralMessages = allMessages + .filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp)); - messageEntities.add(message); - } catch (InvalidProtocolBufferException e) { - logger.warn("Failed to parse envelope", e); + discardStaleEphemeralMessages(destinationUuid, destinationDevice, staleEphemeralMessages); + + return messagesToPublish; + } + + private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message, + long earliestAllowableTimestamp) { + return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp; + } + + private void discardStaleEphemeralMessages(final UUID destinationUuid, final long destinationDevice, + Flux staleEphemeralMessages) { + staleEphemeralMessages + .map(e -> UUID.fromString(e.getServerGuid())) + .buffer(PAGE_SIZE) + .subscribeOn(Schedulers.boundedElastic()) + .subscribe(staleEphemeralMessageGuids -> + remove(destinationUuid, destinationDevice, staleEphemeralMessageGuids) + .thenAccept(removedMessages -> staleEphemeralMessagesCounter.increment(removedMessages.size())), + e -> logger.warn("Could not remove stale ephemeral messages from cache", e)); + } + + @VisibleForTesting + Flux getAllMessages(final UUID destinationUuid, final long destinationDevice) { + + // fetch messages by page + return getNextMessagePage(destinationUuid, destinationDevice, -1) + .expand(queueItemsAndLastMessageId -> { + // expand() is breadth-first, so each page will be published in order + if (queueItemsAndLastMessageId.first().isEmpty()) { + return Mono.empty(); } - } - } else { - logger.error("\"Get messages\" operation returned a list with a non-even number of elements."); - messageEntities = Collections.emptyList(); - } - try { - remove(destinationUuid, destinationDevice, staleEphemeralMessageGuids); - staleEphemeralMessagesCounter.increment(staleEphemeralMessageGuids.size()); - } catch (final Throwable e) { - logger.warn("Could not remove stale ephemeral messages from cache", e); - } + return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second()); + }) + .limitRate(1) + // we want to ensure we don’t accidentally block the Lettuce/netty i/o executors + .publishOn(Schedulers.boundedElastic()) + .map(Pair::first) + .flatMapIterable(queueItems -> { + final List envelopes = new ArrayList<>(queueItems.size() / 2); - return messageEntities; - }); + for (int i = 0; i < queueItems.size() - 1; i += 2) { + try { + final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i)); + + envelopes.add(message); + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + + return envelopes; + }); + } + + private Flux, Long>> getNextMessagePage(final UUID destinationUuid, final long destinationDevice, + long messageId) { + + return getItemsScript.executeBinaryReactive( + List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getPersistInProgressKey(destinationUuid, destinationDevice)), + List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8), + String.valueOf(messageId).getBytes(StandardCharsets.UTF_8))) + .map(result -> { + logger.trace("Processing page: {}", messageId); + + @SuppressWarnings("unchecked") + List queueItems = (List) result; + + if (queueItems.isEmpty()) { + return new Pair<>(Collections.emptyList(), null); + } + + if (queueItems.size() % 2 != 0) { + logger.error("\"Get messages\" operation returned a list with a non-even number of elements."); + return new Pair<>(Collections.emptyList(), null); + } + + final long lastMessageId = Long.parseLong( + new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8)); + + return new Pair<>(queueItems, lastMessageId); + }); } @VisibleForTesting diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index c19621539..6135b5043 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -1,5 +1,5 @@ /* - * Copyright 2021 Signal Messenger, LLC + * Copyright 2021-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -17,19 +17,24 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; -import java.util.stream.Collectors; -import javax.annotation.Nonnull; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.Predicate; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.util.AttributeValues; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; import software.amazon.awssdk.services.dynamodb.model.DeleteRequest; import software.amazon.awssdk.services.dynamodb.model.PutRequest; import software.amazon.awssdk.services.dynamodb.model.QueryRequest; @@ -48,22 +53,25 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { private static final String KEY_ENVELOPE_BYTES = "EB"; private final Timer storeTimer = timer(name(getClass(), "store")); - private final Timer loadTimer = timer(name(getClass(), "load")); - private final Timer deleteByGuid = timer(name(getClass(), "delete", "guid")); - private final Timer deleteByKey = timer(name(getClass(), "delete", "key")); private final Timer deleteByAccount = timer(name(getClass(), "delete", "account")); private final Timer deleteByDevice = timer(name(getClass(), "delete", "device")); + private final DynamoDbAsyncClient dbAsyncClient; private final String tableName; private final Duration timeToLive; + private final ExecutorService messageDeletionExecutor; private static final Logger logger = LoggerFactory.getLogger(MessagesDynamoDb.class); - public MessagesDynamoDb(DynamoDbClient dynamoDb, String tableName, Duration timeToLive) { + public MessagesDynamoDb(DynamoDbClient dynamoDb, DynamoDbAsyncClient dynamoDbAsyncClient, String tableName, + Duration timeToLive, ExecutorService messageDeletionExecutor) { super(dynamoDb); + this.dbAsyncClient = dynamoDbAsyncClient; this.tableName = tableName; this.timeToLive = timeToLive; + + this.messageDeletionExecutor = messageDeletionExecutor; } public void store(final List messages, final UUID destinationAccountUuid, final long destinationDeviceId) { @@ -95,105 +103,105 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems)); } - public List load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) { - return loadTimer.record(() -> { - final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE); - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); - final QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) - .consistentRead(true) - .keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )") - .expressionAttributeNames(Map.of( - "#part", KEY_PARTITION, - "#sort", KEY_SORT)) - .expressionAttributeValues(Map.of( - ":part", partitionKey, - ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) - .limit(numberOfMessagesToFetch) - .build(); - List messageEntities = new ArrayList<>(numberOfMessagesToFetch); - for (Map message : db().queryPaginator(queryRequest).items()) { - try { - messageEntities.add(convertItemToEnvelope(message)); - } catch (final InvalidProtocolBufferException e) { - logger.error("Failed to parse envelope", e); - } + public Publisher load(final UUID destinationAccountUuid, final long destinationDeviceId, + final Integer limit) { - if (messageEntities.size() == numberOfMessagesToFetch) { - // queryPaginator() uses limit() as the page size, not as an absolute limit - // …but a page might be smaller than limit, because a page is capped at 1 MB - break; - } - } - return messageEntities; - }); - } + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final QueryRequest.Builder queryRequestBuilder = QueryRequest.builder() + .tableName(tableName) + .consistentRead(true) + .keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )") + .expressionAttributeNames(Map.of( + "#part", KEY_PARTITION, + "#sort", KEY_SORT)) + .expressionAttributeValues(Map.of( + ":part", partitionKey, + ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))); - public Optional deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid, - final UUID messageUuid) { - return deleteByGuid.record(() -> { - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); - final QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) - .indexName(LOCAL_INDEX_MESSAGE_UUID_NAME) - .projectionExpression(KEY_SORT) - .consistentRead(true) - .keyConditionExpression("#part = :part AND #uuid = :uuid") - .expressionAttributeNames(Map.of( - "#part", KEY_PARTITION, - "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) - .expressionAttributeValues(Map.of( - ":part", partitionKey, - ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))) - .build(); - return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(partitionKey, queryRequest); - }); - } - - public Optional deleteMessage(final UUID destinationAccountUuid, - final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { - return deleteByKey.record(() -> { - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); - final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid); - DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() - .tableName(tableName) - .key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, sortKey)) - .returnValues(ReturnValue.ALL_OLD); - final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); - if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - try { - return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); - } catch (final InvalidProtocolBufferException e) { - logger.error("Failed to parse envelope", e); - return Optional.empty(); - } - } - - return Optional.empty(); - }); - } - - @Nonnull - private Optional deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) { - Optional result = Optional.empty(); - for (Map item : db().queryPaginator(queryRequest).items()) { - final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray(); - DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() - .tableName(tableName) - .key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, AttributeValues.fromByteArray(rangeKeyValue))); - if (result.isEmpty()) { - deleteItemRequest.returnValues(ReturnValue.ALL_OLD); - } - final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); - if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - try { - result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); - } catch (final InvalidProtocolBufferException e) { - logger.error("Failed to parse envelope", e); - } - } + if (limit != null) { + // some callers don’t take advantage of reactive streams, so we want to support limiting the fetch size. Otherwise, + // we could fetch up to 1 MB (likely >1,000 messages) and discard 90% of them + queryRequestBuilder.limit(Math.min(RESULT_SET_CHUNK_SIZE, limit)); } - return result; + + final QueryRequest queryRequest = queryRequestBuilder.build(); + + return dbAsyncClient.queryPaginator(queryRequest).items() + .map(message -> { + try { + return convertItemToEnvelope(message); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + return null; + } + }) + .filter(Predicate.not(Objects::isNull)); + } + + public CompletableFuture> deleteMessageByDestinationAndGuid( + final UUID destinationAccountUuid, final UUID messageUuid) { + + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final QueryRequest queryRequest = QueryRequest.builder() + .tableName(tableName) + .indexName(LOCAL_INDEX_MESSAGE_UUID_NAME) + .projectionExpression(KEY_SORT) + .consistentRead(true) + .keyConditionExpression("#part = :part AND #uuid = :uuid") + .expressionAttributeNames(Map.of( + "#part", KEY_PARTITION, + "#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT)) + .expressionAttributeValues(Map.of( + ":part", partitionKey, + ":uuid", convertLocalIndexMessageUuidSortKey(messageUuid))) + .build(); + + // because we are filtering on message UUID, this query should return at most one item, + // but it’s simpler to handle the full stream and return the “last” item + return Flux.from(dbAsyncClient.queryPaginator(queryRequest).items()) + .flatMap(item -> Mono.fromCompletionStage(dbAsyncClient.deleteItem(DeleteItemRequest.builder() + .tableName(tableName) + .key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, + AttributeValues.fromByteArray(item.get(KEY_SORT).b().asByteArray()))) + .returnValues(ReturnValue.ALL_OLD) + .build()))) + .mapNotNull(deleteItemResponse -> { + try { + if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { + return convertItemToEnvelope(deleteItemResponse.attributes()); + } + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + } + return null; + }) + .last() + .toFuture() + .thenApply(Optional::ofNullable); + } + + public CompletableFuture> deleteMessage(final UUID destinationAccountUuid, + final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { + + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid); + DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() + .tableName(tableName) + .key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, sortKey)) + .returnValues(ReturnValue.ALL_OLD); + + return dbAsyncClient.deleteItem(deleteItemRequest.build()) + .thenApplyAsync(deleteItemResponse -> { + if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { + try { + return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + } + } + + return Optional.empty(); + }, messageDeletionExecutor); } public void deleteAllMessagesForAccount(final UUID destinationAccountUuid) { @@ -248,7 +256,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { KEY_PARTITION, partitionKey, KEY_SORT, item.get(KEY_SORT))).build()) .build()) - .collect(Collectors.toList()); + .toList(); executeTableWriteItemsUntilComplete(Map.of(tableName, deletes)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 70bfb45aa..a20579ba5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.storage; @@ -9,19 +9,30 @@ import static com.codahale.metrics.MetricRegistry.name; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Pair; +import reactor.core.publisher.Flux; public class MessagesManager { private static final int RESULT_SET_CHUNK_SIZE = 100; + private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final Meter cacheHitByGuidMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByGuid")); private static final Meter cacheMissByGuidMeter = metricRegistry.meter( @@ -55,18 +66,32 @@ public class MessagesManager { return messagesCache.hasMessages(destinationUuid, destinationDevice); } - public Pair, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final boolean cachedMessagesOnly) { - List messageList = new ArrayList<>(); + public Pair, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, + boolean cachedMessagesOnly) { - if (!cachedMessagesOnly) { - messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE)); - } + final List envelopes = Flux.from( + getMessagesForDevice(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE, cachedMessagesOnly)) + .take(RESULT_SET_CHUNK_SIZE, true) + .collectList() + .blockOptional().orElse(Collections.emptyList()); - if (messageList.size() < RESULT_SET_CHUNK_SIZE) { - messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size())); - } + return new Pair<>(envelopes, envelopes.size() >= RESULT_SET_CHUNK_SIZE); + } - return new Pair<>(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE); + public Publisher getMessagesForDeviceReactive(UUID destinationUuid, long destinationDevice, + final boolean cachedMessagesOnly) { + + return getMessagesForDevice(destinationUuid, destinationDevice, null, cachedMessagesOnly); + } + + private Publisher getMessagesForDevice(UUID destinationUuid, long destinationDevice, + @Nullable Integer limit, final boolean cachedMessagesOnly) { + + final Publisher dynamoPublisher = + cachedMessagesOnly ? Flux.empty() : messagesDynamoDb.load(destinationUuid, destinationDevice, limit); + final Publisher cachePublisher = messagesCache.get(destinationUuid, destinationDevice); + + return Flux.concat(dynamoPublisher, cachePublisher); } public void clear(UUID destinationUuid) { @@ -79,21 +104,25 @@ public class MessagesManager { messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId); } - public Optional delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) { - Optional removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid); + public CompletableFuture> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, + @Nullable Long serverTimestamp) { + return messagesCache.remove(destinationUuid, destinationDeviceId, guid) + .thenCompose(removed -> { - if (removed.isEmpty()) { - if (serverTimestamp == null) { - removed = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, guid); - } else { - removed = messagesDynamoDb.deleteMessage(destinationUuid, destinationDeviceId, guid, serverTimestamp); - } - cacheMissByGuidMeter.mark(); - } else { - cacheHitByGuidMeter.mark(); - } + if (removed.isPresent()) { + cacheHitByGuidMeter.mark(); + return CompletableFuture.completedFuture(removed); + } - return removed; + cacheMissByGuidMeter.mark(); + + if (serverTimestamp == null) { + return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, guid); + } else { + return messagesDynamoDb.deleteMessage(destinationUuid, destinationDeviceId, guid, serverTimestamp); + } + + }); } /** @@ -112,10 +141,15 @@ public class MessagesManager { final List messageGuids = messages.stream().map(message -> UUID.fromString(message.getServerGuid())) .collect(Collectors.toList()); - int messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDeviceId, messageGuids).size(); - - persistMessageMeter.mark(nonEphemeralMessages.size()); + int messagesRemovedFromCache = 0; + try { + messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDeviceId, messageGuids) + .get(30, TimeUnit.SECONDS).size(); + persistMessageMeter.mark(nonEphemeralMessages.size()); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + logger.warn("Failed to remove messages from cache", e); + } return messagesRemovedFromCache; } @@ -129,4 +163,5 @@ public class MessagesManager { public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) { messagesCache.removeMessageAvailabilityListener(listener); } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index e284e6530..0ad9dd7dd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -11,14 +11,19 @@ import com.codahale.metrics.Counter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -32,32 +37,48 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener; public class AuthenticatedConnectListener implements WebSocketConnectListener { - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration" )); - private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration")); - private static final Counter openWebsocketCounter = metricRegistry.counter(name(WebSocketConnection.class, "open_websockets")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Timer durationTimer = metricRegistry.timer( + name(WebSocketConnection.class, "connected_duration")); + private static final Timer unauthenticatedDurationTimer = metricRegistry.timer( + name(WebSocketConnection.class, "unauthenticated_connection_duration")); + private static final Counter openWebsocketCounter = metricRegistry.counter( + name(WebSocketConnection.class, "open_websockets")); + + private static final String OPEN_WEBSOCKET_COUNTER_NAME = MetricsUtil.name(WebSocketConnection.class, + "openWebsockets"); private static final long RENEW_PRESENCE_INTERVAL_MINUTES = 5; + private static final String REACTIVE_MESSAGE_QUEUE_EXPERIMENT_NAME = "reactive_message_queue_v1"; + private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class); - private final ReceiptSender receiptSender; - private final MessagesManager messagesManager; + private final ReceiptSender receiptSender; + private final MessagesManager messagesManager; private final PushNotificationManager pushNotificationManager; private final ClientPresenceManager clientPresenceManager; private final ScheduledExecutorService scheduledExecutorService; + private final ExperimentEnrollmentManager experimentEnrollmentManager; + + private final AtomicInteger openReactiveWebSockets = new AtomicInteger(0); + private final AtomicInteger openStandardWebSockets = new AtomicInteger(0); public AuthenticatedConnectListener(ReceiptSender receiptSender, MessagesManager messagesManager, PushNotificationManager pushNotificationManager, ClientPresenceManager clientPresenceManager, - ScheduledExecutorService scheduledExecutorService) - { - this.receiptSender = receiptSender; - this.messagesManager = messagesManager; + ScheduledExecutorService scheduledExecutorService, + ExperimentEnrollmentManager experimentEnrollmentManager) { + this.receiptSender = receiptSender; + this.messagesManager = messagesManager; this.pushNotificationManager = pushNotificationManager; this.clientPresenceManager = clientPresenceManager; this.scheduledExecutorService = scheduledExecutorService; + this.experimentEnrollmentManager = experimentEnrollmentManager; + + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of("reactive", String.valueOf(true)), openReactiveWebSockets); + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of("reactive", String.valueOf(false)), openStandardWebSockets); } @Override @@ -66,43 +87,56 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class); final Device device = auth.getAuthenticatedDevice(); final Timer.Context timer = durationTimer.time(); + final boolean enrolledInReactiveMessageQueue = experimentEnrollmentManager.isEnrolled( + auth.getAccount().getUuid(), + REACTIVE_MESSAGE_QUEUE_EXPERIMENT_NAME); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, context.getClient(), - scheduledExecutorService); + scheduledExecutorService, + enrolledInReactiveMessageQueue); openWebsocketCounter.inc(); + if (enrolledInReactiveMessageQueue) { + openReactiveWebSockets.incrementAndGet(); + } else { + openStandardWebSockets.incrementAndGet(); + } + pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, context.getClient().getUserAgent()); final AtomicReference> renewPresenceFutureReference = new AtomicReference<>(); - context.addListener(new WebSocketSessionContext.WebSocketEventListener() { - @Override - public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { - openWebsocketCounter.dec(); - timer.stop(); - - final ScheduledFuture renewPresenceFuture = renewPresenceFutureReference.get(); - - if (renewPresenceFuture != null) { - renewPresenceFuture.cancel(false); - } - - connection.stop(); - - RedisOperation.unchecked( - () -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId())); - RedisOperation.unchecked(() -> { - messagesManager.removeMessageAvailabilityListener(connection); - - if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) { - try { - pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true); - } catch (NotPushRegisteredException ignored) { - } - } - }); + context.addListener((closingContext, statusCode, reason) -> { + openWebsocketCounter.dec(); + if (enrolledInReactiveMessageQueue) { + openReactiveWebSockets.decrementAndGet(); + } else { + openStandardWebSockets.decrementAndGet(); } + + timer.stop(); + + final ScheduledFuture renewPresenceFuture = renewPresenceFutureReference.get(); + + if (renewPresenceFuture != null) { + renewPresenceFuture.cancel(false); + } + + connection.stop(); + + RedisOperation.unchecked( + () -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId())); + RedisOperation.unchecked(() -> { + messagesManager.removeMessageAvailabilityListener(connection); + + if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) { + try { + pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true); + } catch (NotPushRegisteredException ignored) { + } + } + }); }); try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 5e47bd429..1a9a73996 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -1,12 +1,11 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.websocket; import static com.codahale.metrics.MetricRegistry.name; -import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; @@ -34,11 +33,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import javax.ws.rs.WebApplicationException; import org.apache.commons.lang3.StringUtils; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; @@ -49,13 +50,14 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil; -import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; -@SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); @@ -70,8 +72,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac name(WebSocketConnection.class, "messagesPersisted")); private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent")); private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures")); - private static final Meter discardedMessagesMeter = metricRegistry.meter( - name(WebSocketConnection.class, "discardedMessages")); private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class, "initialQueueLength"); @@ -85,11 +85,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac "messageAvailableAfterClientClosed"); private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; + private static final String REACTIVE_TAG = "reactive"; private static final long SLOW_DRAIN_THRESHOLD = 10_000; @VisibleForTesting - static final int MAX_DESKTOP_MESSAGE_SIZE = 1024 * 1024; + static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100; @VisibleForTesting static final int MAX_CONSECUTIVE_RETRIES = 5; @@ -111,18 +112,19 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final ScheduledExecutorService scheduledExecutorService; - private final boolean isDesktopClient; - private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); private final AtomicReference storedMessageState = new AtomicReference<>( StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); private final LongAdder sentMessageCounter = new LongAdder(); private final AtomicLong queueDrainStartTime = new AtomicLong(); - private final AtomicInteger consecutiveRetries = new AtomicInteger(); - private final AtomicReference> retryFuture = new AtomicReference<>(); + private final AtomicInteger consecutiveRetries = new AtomicInteger(); + private final AtomicReference> retryFuture = new AtomicReference<>(); + private final AtomicReference messageSubscription = new AtomicReference<>(); private final Random random = new Random(); + private final boolean useReactive; + private Scheduler reactiveScheduler; private enum StoredMessageState { EMPTY, @@ -135,7 +137,28 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac AuthenticatedAccount auth, Device device, WebSocketClient client, - ScheduledExecutorService scheduledExecutorService) { + ScheduledExecutorService scheduledExecutorService, + boolean useReactive) { + + this(receiptSender, + messagesManager, + auth, + device, + client, + scheduledExecutorService, + useReactive, + Schedulers.boundedElastic()); + } + + @VisibleForTesting + WebSocketConnection(ReceiptSender receiptSender, + MessagesManager messagesManager, + AuthenticatedAccount auth, + Device device, + WebSocketClient client, + ScheduledExecutorService scheduledExecutorService, + boolean useReactive, + Scheduler reactiveScheduler) { this(receiptSender, messagesManager, @@ -143,7 +166,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac device, client, DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, - scheduledExecutorService); + scheduledExecutorService, + useReactive, + reactiveScheduler); } @VisibleForTesting @@ -153,7 +178,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac Device device, WebSocketClient client, int sendFuturesTimeoutMillis, - ScheduledExecutorService scheduledExecutorService) { + ScheduledExecutorService scheduledExecutorService, + boolean useReactive, + Scheduler reactiveScheduler) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; @@ -162,16 +189,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this.client = client; this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis; this.scheduledExecutorService = scheduledExecutorService; - - Optional maybePlatform; - - try { - maybePlatform = Optional.of(UserAgentUtil.parseUserAgentString(client.getUserAgent()).getPlatform()); - } catch (final UnrecognizedUserAgentException e) { - maybePlatform = Optional.empty(); - } - - this.isDesktopClient = maybePlatform.map(platform -> platform == ClientPlatform.DESKTOP).orElse(false); + this.useReactive = useReactive; + this.reactiveScheduler = reactiveScheduler; } public void start() { @@ -186,10 +205,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac future.cancel(false); } + final Disposable subscription = messageSubscription.get(); + if (subscription != null) { + subscription.dispose(); + } + client.close(1000, "OK"); } - private CompletableFuture sendMessage(final Envelope message, final Optional storedMessageInfo) { + private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { // clear ephemeral field from the envelope final Optional body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray()); @@ -199,33 +223,43 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac MessageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message); // X-Signal-Key: false must be sent until Android stops assuming it missing means true - return client.sendRequest("PUT", "/api/v1/message", List.of("X-Signal-Key: false", TimestampHeaderUtil.getTimestampHeader()), body).whenComplete((response, throwable) -> { - if (throwable == null) { - if (isSuccessResponse(response)) { - if (storedMessageInfo.isPresent()) { - messagesManager.delete(auth.getAccount().getUuid(), device.getId(), storedMessageInfo.get().getGuid(), storedMessageInfo.get().getServerTimestamp()); - } - - if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { - recordMessageDeliveryDuration(message.getTimestamp(), device); - sendDeliveryReceiptFor(message); - } - } else { - final List tags = new ArrayList<>( - List.of(Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())), - UserAgentTagUtil.getPlatformTag(client.getUserAgent()))); - - // TODO Remove this once we've identified the cause of message rejections from desktop clients - if (StringUtils.isNotBlank(response.getMessage())) { - tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); + return client.sendRequest("PUT", "/api/v1/message", + List.of("X-Signal-Key: false", TimestampHeaderUtil.getTimestampHeader()), body) + .whenComplete((ignored, throwable) -> { + if (throwable != null) { + sendFailuresMeter.mark(); } + }).thenCompose(response -> { + final CompletableFuture result; + if (isSuccessResponse(response)) { - Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); - } - } else { - sendFailuresMeter.mark(); - } - }); + result = messagesManager.delete(auth.getAccount().getUuid(), device.getId(), + storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()); + + if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { + recordMessageDeliveryDuration(message.getTimestamp(), device); + sendDeliveryReceiptFor(message); + } + } else { + final List tags = new ArrayList<>( + List.of( + Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())), + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + Tag.of(REACTIVE_TAG, String.valueOf(useReactive)) + )); + + // TODO Remove this once we've identified the cause of message rejections from desktop clients + if (StringUtils.isNotBlank(response.getMessage())) { + tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); + } + + Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); + + result = CompletableFuture.completedFuture(null); + } + + return result; + }); } public static void recordMessageDeliveryDuration(long timestamp, Device messageDestinationDevice) { @@ -260,65 +294,96 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac @VisibleForTesting void processStoredMessages() { - if (processStoredMessagesSemaphore.tryAcquire()) { - final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); - final CompletableFuture queueClearedFuture = new CompletableFuture<>(); - - sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueClearedFuture); - - queueClearedFuture.whenComplete((v, cause) -> { - if (cause == null) { - consecutiveRetries.set(0); - - if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); - final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); - - Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); - Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); - - if (drainDuration > SLOW_DRAIN_THRESHOLD) { - Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); - } - - client.sendRequest("PUT", "/api/v1/queue/empty", - Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); - } - } else { - storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); - } - - processStoredMessagesSemaphore.release(); - - if (cause == null) { - if (storedMessageState.get() != StoredMessageState.EMPTY) { - processStoredMessages(); - } - } else { - if (client.isOpen()) { - - if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) { - logger.warn("Max consecutive retries exceeded", cause); - client.close(1011, "Failed to retrieve messages"); - } else { - logger.debug("Failed to clear queue", cause); - final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); - - Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment(); - - final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS); - retryFuture - .set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS)); - } - } else { - logger.debug("Client disconnected before queue cleared"); - } - } - }); + if (useReactive) { + processStoredMessages_reactive(); + } else { + processStoredMessage_paged(); } } - private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { + private void processStoredMessage_paged() { + assert !useReactive; + + if (processStoredMessagesSemaphore.tryAcquire()) { + final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); + final CompletableFuture queueCleared = new CompletableFuture<>(); + + sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueCleared); + + setQueueClearedHandler(state, queueCleared); + } + } + + private void setQueueClearedHandler(final StoredMessageState state, final CompletableFuture queueCleared) { + + queueCleared.whenComplete((v, cause) -> { + if (cause == null) { + consecutiveRetries.set(0); + + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + final List tags = List.of( + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + Tag.of(REACTIVE_TAG, String.valueOf(useReactive)) + ); + final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); + + Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); + Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); + + if (drainDuration > SLOW_DRAIN_THRESHOLD) { + Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); + } + + client.sendRequest("PUT", "/api/v1/queue/empty", + Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } + } else { + storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); + } + + processStoredMessagesSemaphore.release(); + + if (cause == null) { + if (storedMessageState.get() != StoredMessageState.EMPTY) { + processStoredMessages(); + } + } else { + if (client.isOpen()) { + + if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) { + logger.warn("Max consecutive retries exceeded", cause); + client.close(1011, "Failed to retrieve messages"); + } else { + logger.debug("Failed to clear queue", cause); + final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); + + Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment(); + + final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS); + retryFuture + .set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS)); + } + } else { + logger.debug("Client disconnected before queue cleared"); + } + } + }); + } + + private void processStoredMessages_reactive() { + assert useReactive; + + if (processStoredMessagesSemaphore.tryAcquire()) { + final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); + final CompletableFuture queueCleared = new CompletableFuture<>(); + + sendMessagesReactive(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueCleared); + + setQueueClearedHandler(state, queueCleared); + } + } + + private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueCleared) { try { final Pair, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice( auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly); @@ -330,25 +395,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac for (int i = 0; i < messages.size(); i++) { final Envelope envelope = messages.get(i); - final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); - - final boolean discard; - if (isDesktopClient && envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE) { - discard = true; - } else if (envelope.getStory() && !client.shouldDeliverStories()) { - discard = true; - } else { - discard = false; - } - - if (discard) { - messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); - discardedMessagesMeter.mark(); - - sendFutures[i] = CompletableFuture.completedFuture(null); - } else { - sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(messageGuid, envelope.getServerTimestamp()))); - } + sendFutures[i] = sendMessage(envelope); } // Set a large, non-zero timeout, to prevent any failure to acknowledge receipt from blocking indefinitely @@ -357,16 +404,45 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac .whenComplete((v, cause) -> { if (cause == null) { if (hasMore) { - sendNextMessagePage(cachedMessagesOnly, queueClearedFuture); + sendNextMessagePage(cachedMessagesOnly, queueCleared); } else { - queueClearedFuture.complete(null); + queueCleared.complete(null); } } else { - queueClearedFuture.completeExceptionally(cause); + queueCleared.completeExceptionally(cause); } }); } catch (final Exception e) { - queueClearedFuture.completeExceptionally(e); + queueCleared.completeExceptionally(e); + } + } + + private void sendMessagesReactive(final boolean cachedMessagesOnly, final CompletableFuture queueCleared) { + + final Publisher messages = + messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly); + + final Disposable subscription = Flux.from(messages) + .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) + .flatMapSequential(envelope -> + Mono.fromFuture(sendMessage(envelope).orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS))) + .doOnError(queueCleared::completeExceptionally) + .doOnComplete(() -> queueCleared.complete(null)) + .subscribeOn(reactiveScheduler) + .subscribe(); + + messageSubscription.set(subscription); + } + + private CompletableFuture sendMessage(Envelope envelope) { + final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); + + if (envelope.getStory() && !client.shouldDeliverStories()) { + messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); + + return CompletableFuture.completedFuture(null); + } else { + return sendMessage(envelope, new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())); } } @@ -381,6 +457,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac messageAvailableMeter.mark(); storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); + processStoredMessages(); return true; @@ -396,6 +473,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac messagesPersistedMeter.mark(); storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + processStoredMessages(); return true; @@ -405,7 +483,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac public void handleDisplacement(final boolean connectedElsewhere) { final Tags tags = Tags.of( UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere))); + Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)), + Tag.of(REACTIVE_TAG, String.valueOf(useReactive))); Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); @@ -429,21 +508,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } } - private static class StoredMessageInfo { - private final UUID guid; - private final long serverTimestamp; + private record StoredMessageInfo(UUID guid, long serverTimestamp) { - public StoredMessageInfo(UUID guid, long serverTimestamp) { - this.guid = guid; - this.serverTimestamp = serverTimestamp; - } - - public UUID getGuid() { - return guid; - } - - public long getServerTimestamp() { - return serverTimestamp; - } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java index 1696b61cf..3fdc3b598 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java @@ -25,7 +25,6 @@ import net.sourceforge.argparse4j.inf.Subparser; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; -import org.whispersystems.textsecuregcm.push.PushLatencyManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; @@ -45,9 +44,9 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.ProfilesManager; +import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames; import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; -import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException; import org.whispersystems.textsecuregcm.storage.VerificationCodeStore; @@ -97,6 +96,8 @@ public class AssignUsernameCommand extends EnvironmentCommand commands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder().stringCommands(commands).build(); final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; @@ -51,7 +62,7 @@ public class ClusterLuaScriptTest { @Test void testExecuteScriptNotLoaded() { final RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder().stringCommands(commands).build(); final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; @@ -71,8 +82,10 @@ public class ClusterLuaScriptTest { void testExecuteBinaryScriptNotLoaded() { final RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); final RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster mockCluster = - RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder() + .stringCommands(stringCommands) + .binaryCommands(binaryCommands) + .build(); final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; @@ -85,17 +98,85 @@ public class ClusterLuaScriptTest { luaScript.executeBinary(keys, values); verify(binaryCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); - verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][])); + verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), + values.toArray(new byte[0][])); } @Test - public void testExecuteRealCluster() { + void testExecuteBinaryAsyncScriptNotLoaded() throws Exception { + final RedisAdvancedClusterAsyncCommands binaryAsyncCommands = + mock(RedisAdvancedClusterAsyncCommands.class); + final FaultTolerantRedisCluster mockCluster = + RedisClusterHelper.builder().binaryAsyncCommands(binaryAsyncCommands).build(); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); + final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); + + final AsyncCommand evalShaFailure = new AsyncCommand<>(mock(RedisCommand.class)); + evalShaFailure.completeExceptionally(new RedisNoScriptException("OH NO")); + + final AsyncCommand evalSuccess = new AsyncCommand<>(mock(RedisCommand.class)); + evalSuccess.complete(); + + when(binaryAsyncCommands.evalsha(any(), any(), any(), any())).thenReturn((RedisFuture) evalShaFailure); + when(binaryAsyncCommands.eval(anyString(), any(), any(), any())).thenReturn((RedisFuture) evalSuccess); + + final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType); + luaScript.executeBinaryAsync(keys, values).get(5, TimeUnit.SECONDS); + + verify(binaryAsyncCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), + values.toArray(new byte[0][])); + verify(binaryAsyncCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), + values.toArray(new byte[0][])); + } + + @Test + void testExecuteBinaryReactiveScriptNotLoaded() { + final RedisAdvancedClusterReactiveCommands binaryReactiveCommands = + mock(RedisAdvancedClusterReactiveCommands.class); + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder() + .binaryReactiveCommands(binaryReactiveCommands).build(); + + final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])"; + final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE; + final List keys = List.of("key".getBytes(StandardCharsets.UTF_8)); + final List values = List.of("value".getBytes(StandardCharsets.UTF_8)); + + when(binaryReactiveCommands.evalsha(any(), any(), any(), any())) + .thenReturn(Flux.error(new RedisNoScriptException("OH NO"))); + when(binaryReactiveCommands.eval(anyString(), any(), any(), any())).thenReturn(Flux.just("ok")); + + final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType); + luaScript.executeBinaryReactive(keys, values).blockLast(Duration.ofSeconds(5)); + + verify(binaryReactiveCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), + values.toArray(new byte[0][])); + verify(binaryReactiveCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), + values.toArray(new byte[0][])); + } + + @ParameterizedTest + @EnumSource(ExecuteMode.class) + void testExecuteRealCluster(final ExecuteMode mode) throws Exception { + REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(c -> c.sync().scriptFlush(FlushMode.SYNC)); + REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(c -> c.sync().configResetstat()); + final ClusterLuaScript script = new ClusterLuaScript(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "return 2;", ScriptOutputType.INTEGER); for (int i = 0; i < 7; i++) { - assertEquals(2L, script.execute(Collections.emptyList(), Collections.emptyList())); + final long actual = switch (mode) { + case SYNC -> (long) script.execute(Collections.emptyList(), Collections.emptyList()); + case ASYNC -> + (long) script.executeAsync(Collections.emptyList(), Collections.emptyList()).get(5, TimeUnit.SECONDS); + case REACTIVE -> (long) script.executeReactive(Collections.emptyList(), Collections.emptyList()) + .blockLast(Duration.ofSeconds(5)); + }; + + assertEquals(2L, actual); } final int evalCount = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> { @@ -120,4 +201,11 @@ public class ClusterLuaScriptTest { assertEquals(1, evalCount); } + + private enum ExecuteMode { + SYNC, + ASYNC, + REACTIVE + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index cb11eecbe..d45d6543c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -155,7 +155,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, - RedisClusterHelper.buildMockRedisCluster(commands), + RedisClusterHelper.builder().stringCommands(commands).build(), deletedAccountsManager, mock(DirectoryQueue.class), mock(Keys.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 8dfacb199..5a70e0798 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -147,7 +147,7 @@ class AccountsManagerTest { accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, - RedisClusterHelper.buildMockRedisCluster(commands), + RedisClusterHelper.builder().stringCommands(commands).build(), deletedAccountsManager, directoryQueue, keys, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java index 0c8fb0bde..bcf881716 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java @@ -78,7 +78,14 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback } @Override - public void afterEach(ExtensionContext context) throws Exception { + public void afterEach(ExtensionContext context) { + stopServer(); + } + + /** + * For use in integration tests that want to test resiliency/error handling + */ + public void stopServer() { try { server.stop(); } catch (Exception e) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 793bac9ba..1a08d62c5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -15,6 +15,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.lettuce.core.cluster.SlotHash; import java.nio.ByteBuffer; +import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -32,7 +33,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.push.PushLatencyManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -47,6 +47,7 @@ class MessagePersisterIntegrationTest { static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); private ExecutorService notificationExecutorService; + private ExecutorService messageDeletionExecutorService; private MessagesCache messagesCache; private MessagesManager messagesManager; private MessagePersister messagePersister; @@ -66,13 +67,16 @@ class MessagePersisterIntegrationTest { when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + messageDeletionExecutorService = Executors.newSingleThreadExecutor(); final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), - MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14)); + dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14), + messageDeletionExecutorService); final AccountsManager accountsManager = mock(AccountsManager.class); notificationExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), - REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); + REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), notificationExecutorService, + messageDeletionExecutorService); messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class)); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); @@ -94,6 +98,9 @@ class MessagePersisterIntegrationTest { void tearDown() throws Exception { notificationExecutorService.shutdown(); notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS); + + messageDeletionExecutorService.shutdown(); + messageDeletionExecutorService.awaitTermination(15, TimeUnit.SECONDS); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index bf0b36bc0..0f63fa89f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; import io.lettuce.core.cluster.SlotHash; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.List; @@ -46,7 +47,7 @@ class MessagePersisterTest { @RegisterExtension static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - private ExecutorService notificationExecutorService; + private ExecutorService sharedExecutorService; private MessagesCache messagesCache; private MessagesDynamoDb messagesDynamoDb; private MessagePersister messagePersister; @@ -74,9 +75,9 @@ class MessagePersisterTest { when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); - notificationExecutorService = Executors.newSingleThreadExecutor(); + sharedExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), - REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); + REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, sharedExecutorService); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); @@ -88,7 +89,7 @@ class MessagePersisterTest { messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); for (final MessageProtos.Envelope message : messages) { - messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); + messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get(); } return null; @@ -97,8 +98,8 @@ class MessagePersisterTest { @AfterEach void tearDown() throws Exception { - notificationExecutorService.shutdown(); - notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); + sharedExecutorService.shutdown(); + sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 2e603ff26..b4cb88b7a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -7,16 +7,34 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; +import io.lettuce.core.RedisFuture; import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; +import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.RedisCommand; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; +import java.util.Deque; import java.util.List; import java.util.Optional; import java.util.Random; @@ -26,191 +44,692 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.test.StepVerifier; class MessagesCacheTest { - @RegisterExtension - static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - - private ExecutorService notificationExecutorService; - private MessagesCache messagesCache; - private final Random random = new Random(); private long serialTimestamp = 0; - private static final UUID DESTINATION_UUID = UUID.randomUUID(); - private static final int DESTINATION_DEVICE_ID = 7; + @Nested + class WithRealCluster { - @BeforeEach - void setUp() throws Exception { + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> { - connection.sync().flushall(); - connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); - }); + private ExecutorService sharedExecutorService; + private MessagesCache messagesCache; - notificationExecutorService = Executors.newSingleThreadExecutor(); - messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); + private static final UUID DESTINATION_UUID = UUID.randomUUID(); + private static final int DESTINATION_DEVICE_ID = 7; - messagesCache.start(); - } + @BeforeEach + void setUp() throws Exception { - @AfterEach - void tearDown() throws Exception { - messagesCache.stop(); + REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> { + connection.sync().flushall(); + connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); + }); - notificationExecutorService.shutdown(); - notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); - } + sharedExecutorService = Executors.newSingleThreadExecutor(); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, + sharedExecutorService); - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testInsert(final boolean sealedSender) { - final UUID messageGuid = UUID.randomUUID(); - assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, sealedSender)) > 0); - } - - @Test - void testDoubleInsertGuid() { - final UUID duplicateGuid = UUID.randomUUID(); - final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); - - final long firstId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); - final long secondId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - duplicateMessage); - - assertEquals(firstId, secondId); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testRemoveByUUID(final boolean sealedSender) { - final UUID messageGuid = UUID.randomUUID(); - - assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); - - final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, - DESTINATION_DEVICE_ID, messageGuid); - - assertEquals(Optional.of(message), maybeRemovedMessage); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testRemoveBatchByUUID(final boolean sealedSender) { - final int messageCount = 10; - - final List messagesToRemove = new ArrayList<>(messageCount); - final List messagesToPreserve = new ArrayList<>(messageCount); - - for (int i = 0; i < 10; i++) { - messagesToRemove.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); - messagesToPreserve.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); + messagesCache.start(); } - assertEquals(Collections.emptyList(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, - messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) - .collect(Collectors.toList()))); + @AfterEach + void tearDown() throws Exception { + messagesCache.stop(); - for (final MessageProtos.Envelope message : messagesToRemove) { - messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + sharedExecutorService.shutdown(); + sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS); } - for (final MessageProtos.Envelope message : messagesToPreserve) { - messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - } - - final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, - messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) - .collect(Collectors.toList())); - - assertEquals(messagesToRemove, removedMessages); - assertEquals(messagesToPreserve, - messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - } - - @Test - void testHasMessages() { - assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); - - final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - - assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testGetMessages(final boolean sealedSender) { - final int messageCount = 100; - - final List expectedMessages = new ArrayList<>(messageCount); - - for (int i = 0; i < messageCount; i++) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testInsert(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); + assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid, sealedSender)) > 0); + } + + @Test + void testDoubleInsertGuid() { + final UUID duplicateGuid = UUID.randomUUID(); + final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); + + final long firstId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + duplicateMessage); + final long secondId = messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + duplicateMessage); + + assertEquals(firstId, secondId); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRemoveByUUID(final boolean sealedSender) throws Exception { + final UUID messageGuid = UUID.randomUUID(); + + assertEquals(Optional.empty(), + messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS)); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, + DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS); + + assertEquals(Optional.of(message), maybeRemovedMessage); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRemoveBatchByUUID(final boolean sealedSender) throws Exception { + final int messageCount = 10; + + final List messagesToRemove = new ArrayList<>(messageCount); + final List messagesToPreserve = new ArrayList<>(messageCount); + + for (int i = 0; i < 10; i++) { + messagesToRemove.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); + messagesToPreserve.add(generateRandomMessage(UUID.randomUUID(), sealedSender)); + } + + assertEquals(Collections.emptyList(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) + .collect(Collectors.toList())).get(5, TimeUnit.SECONDS)); + + for (final MessageProtos.Envelope message : messagesToRemove) { + messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, + message); + } + + for (final MessageProtos.Envelope message : messagesToPreserve) { + messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, + message); + } + + final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) + .collect(Collectors.toList())).get(5, TimeUnit.SECONDS); + + assertEquals(messagesToRemove, removedMessages); + assertEquals(messagesToPreserve, + messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + } + + @Test + void testHasMessages() { + assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); + + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - expectedMessages.add(message); + assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); } - assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testGetMessages(final boolean sealedSender) throws Exception { + final int messageCount = 100; - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testClearQueueForDevice(final boolean sealedSender) { - final int messageCount = 100; + final List expectedMessages = new ArrayList<>(messageCount); - for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + expectedMessages.add(message); } + + assertEquals(expectedMessages, get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + + messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + expectedMessages.stream() + .map(MessageProtos.Envelope::getServerGuid) + .map(UUID::fromString) + .collect(Collectors.toList())); + + final UUID message1Guid = UUID.randomUUID(); + final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender); + messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1); + final List get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID, + 1); + assertEquals(List.of(message1), get1); + + messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, message1Guid).get(5, TimeUnit.SECONDS); + + final UUID message2Guid = UUID.randomUUID(); + final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender); + + messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2); + + assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1)); } - messagesCache.clear(DESTINATION_UUID, DESTINATION_DEVICE_ID); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testGetMessagesPublisher(final boolean expectStale) throws Exception { + final int messageCount = 214; - assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(messageCount, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); - } + final List expectedMessages = new ArrayList<>(messageCount); - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testClearQueueForAccount(final boolean sealedSender) { - final int messageCount = 100; - - for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + expectedMessages.add(message); + } + + final UUID ephemeralMessageGuid = UUID.randomUUID(); + final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true) + .toBuilder().setEphemeral(true).build(); + messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage); + + final Clock cacheClock; + if (expectStale) { + cacheClock = Clock.fixed(Instant.ofEpochMilli(serialTimestamp + 1), + ZoneId.of("Etc/UTC")); + } else { + cacheClock = Clock.fixed( + Instant.ofEpochMilli(serialTimestamp + 1).plus(MessagesCache.MAX_EPHEMERAL_MESSAGE_DELAY), + ZoneId.of("Etc/UTC")); + } + + final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + cacheClock, + sharedExecutorService, + sharedExecutorService); + + final List actualMessages = Flux.from( + messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID)) + .collectList() + .block(Duration.ofSeconds(5)); + + if (expectStale) { + final List expectedAllMessages = new ArrayList<>() {{ + addAll(expectedMessages); + add(ephemeralMessage); + }}; + + assertEquals(expectedAllMessages, actualMessages); + + } else { + assertEquals(expectedMessages, actualMessages); + + // delete all of these messages and call `getAll()`, to confirm that ephemeral messages have been discarded + CompletableFuture.allOf(actualMessages.stream() + .map(message -> messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + UUID.fromString(message.getServerGuid()))) + .toArray(CompletableFuture[]::new)) + .get(5, TimeUnit.SECONDS); + + final List messages = messagesCache.getAllMessages(DESTINATION_UUID, + DESTINATION_DEVICE_ID) + .collectList() + .toFuture().get(5, TimeUnit.SECONDS); + + assertTrue(messages.isEmpty()); } } - messagesCache.clear(DESTINATION_UUID); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testClearQueueForDevice(final boolean sealedSender) { + final int messageCount = 100; - assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount)); + for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + } + } + + messagesCache.clear(DESTINATION_UUID, DESTINATION_DEVICE_ID); + + assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(messageCount, get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testClearQueueForAccount(final boolean sealedSender) { + final int messageCount = 100; + + for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + } + } + + messagesCache.clear(DESTINATION_UUID); + + assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount)); + } + + @Test + void testClearNullUuid() { + // We're happy as long as this doesn't throw an exception + messagesCache.clear(null); + } + + @Test + void testGetAccountFromQueueName() { + assertEquals(DESTINATION_UUID, + MessagesCache.getAccountUuidFromQueueName( + new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), + StandardCharsets.UTF_8))); + } + + @Test + void testGetDeviceIdFromQueueName() { + assertEquals(DESTINATION_DEVICE_ID, + MessagesCache.getDeviceIdFromQueueName( + new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), + StandardCharsets.UTF_8))); + } + + @Test + void testGetQueueNameFromKeyspaceChannel() { + assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7", + MessagesCache.getQueueNameFromKeyspaceChannel( + "__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}")); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGetQueuesToPersist(final boolean sealedSender) { + final UUID messageGuid = UUID.randomUUID(); + + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid, sealedSender)); + final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID); + + assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty()); + + final List queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100); + + assertEquals(1, queues.size()); + assertEquals(DESTINATION_UUID, MessagesCache.getAccountUuidFromQueueName(queues.get(0))); + assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0))); + } + + @Test + void testNotifyListenerNewMessage() { + final AtomicBoolean notified = new AtomicBoolean(false); + final UUID messageGuid = UUID.randomUUID(); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public boolean handleNewMessagesAvailable() { + synchronized (notified) { + notified.set(true); + notified.notifyAll(); + + return true; + } + } + + @Override + public boolean handleMessagesPersisted() { + return true; + } + }; + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid, true)); + + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } + } + + assertTrue(notified.get()); + }); + } + + @Test + void testNotifyListenerPersisted() { + final AtomicBoolean notified = new AtomicBoolean(false); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public boolean handleNewMessagesAvailable() { + return true; + } + + @Override + public boolean handleMessagesPersisted() { + synchronized (notified) { + notified.set(true); + notified.notifyAll(); + + return true; + } + } + }; + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + + messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } + } + + assertTrue(notified.get()); + }); + } + + + /** + * Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning + * {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been + * called. + *

+ * It uses a {@link CompletableFuture} to signal that it has received a “messages available” callback for the first + * time. + */ + private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener { + + private int counter; + + private final Consumer messageHandledCallback; + private final CompletableFuture firstMessageHandled = new CompletableFuture<>(); + + private NewMessagesAvailabilityClosedListener(final Consumer messageHandledCallback) { + this.messageHandledCallback = messageHandledCallback; + } + + @Override + public boolean handleNewMessagesAvailable() { + counter++; + messageHandledCallback.accept(counter); + firstMessageHandled.complete(null); + + return false; + + } + + @Override + public boolean handleMessagesPersisted() { + return true; + } + } + + @Test + void testAvailabilityListenerResponses() { + final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener( + count -> assertEquals(1, count)); + final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener( + count -> assertEquals(1, count)); + + assertTimeoutPreemptively(Duration.ofSeconds(30), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1); + final UUID messageGuid1 = UUID.randomUUID(); + messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid1, true)); + + listener1.firstMessageHandled.get(); + + // Avoid a race condition by blocking on the message handled future *and* the current notification executor task— + // the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes + // `listener2` will get subscribed before `listener1` is cleaned up + sharedExecutorService.submit(() -> listener1.firstMessageHandled.get()).get(); + + final UUID messageGuid2 = UUID.randomUUID(); + messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid2, true)); + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2); + + final UUID messageGuid3 = UUID.randomUUID(); + messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid3, true)); + + listener2.firstMessageHandled.get(); + }); + } + + private List get(final UUID destinationUuid, final long destinationDeviceId, + final int messageCount) { + return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId)) + .take(messageCount, true) + .collectList() + .block(); + } + + } + + @Nested + class WithMockCluster { + + private MessagesCache messagesCache; + private RedisAdvancedClusterReactiveCommands reactiveCommands; + private RedisAdvancedClusterAsyncCommands asyncCommands; + + @SuppressWarnings("unchecked") + @BeforeEach + void setup() throws Exception { + reactiveCommands = mock(RedisAdvancedClusterReactiveCommands.class); + asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); + + final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder() + .binaryReactiveCommands(reactiveCommands) + .binaryAsyncCommands(asyncCommands) + .build(); + + messagesCache = new MessagesCache(mockCluster, mockCluster, Clock.systemUTC(), mock(ExecutorService.class), + mock(ExecutorService.class)); + } + + @AfterEach + void teardown() { + StepVerifier.resetDefaultTimeout(); + } + + @Test + void testGetAllMessagesLimitsAndBackpressure() { + // this test makes sure that we don’t fetch and buffer all messages from the cache when the publisher + // is subscribed. Rather, we should be fetching in pages to satisfy downstream requests, so that memory usage + // is limited to few pages of messages + + // we use a combination of Flux.just() and Sinks to control when data is “fetched” from the cache. The initial + // Flux.just()s are pages that are readily available, on demand. By design, there are more of these pages than + // the initial prefetch. The sinks allow us to create extra demand but defer producing values to satisfy the demand + // until later on. + + final AtomicReference> page4Sink = new AtomicReference<>(); + final AtomicReference> page56Sink = new AtomicReference<>(); + final AtomicReference> emptyFinalPageSink = new AtomicReference<>(); + + final Deque> pages = new ArrayDeque<>(); + pages.add(generatePage()); + pages.add(generatePage()); + pages.add(generatePage()); + pages.add(generatePage()); + // make sure that stale ephemeral messages are also produced by calls to getAllMessages() + pages.add(generateStaleEphemeralPage()); + pages.add(generatePage()); + + when(reactiveCommands.evalsha(any(), any(), any(), any())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.create(sink -> page4Sink.compareAndSet(null, sink))) + .thenReturn(Flux.create(sink -> page56Sink.compareAndSet(null, sink))) + .thenReturn(Flux.create(sink -> emptyFinalPageSink.compareAndSet(null, sink))) + .thenReturn(Flux.empty()); + + final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), 1L); + + // Why initialValue = 3? + // 1. messagesCache.getAllMessages() above produces the first call + // 2. when we subscribe, the prefetch of 1 results in `expand()`, which produces a second call + // 3. there is an implicit “low tide mark” of 1, meaning there will be an extra call to replenish when there is + // 1 value remaining + final AtomicInteger expectedReactiveCommandInvocations = new AtomicInteger(3); + + StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); + + final int page = 100; + final int halfPage = page / 2; + + // in order to fully control demand and separate the prefetch mechanics, initially subscribe with a request of 0 + StepVerifier.create(allMessages, 0) + .expectSubscription() + .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.get())).evalsha(any(), any(), + any(), any())) + .thenRequest(halfPage) // page 0.5 requested + .expectNextCount(halfPage) // page 0.5 produced + // page 0.5 produced, 1.5 remain, so no additional interactions with the cache cluster + .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.get())).evalsha(any(), + any(), any(), any())) + .then(() -> assertNull(page4Sink.get(), "page 4 should not have been fetched yet")) + .thenRequest(page) // page 1.5 requested + .expectNextCount(page) // page 1.5 produced + + // we now have produced 1.5 pages, have 0.5 buffered, and two more have been prefetched. + // after producing more than a full page, we’ll need to replenish from the cache. + // future requests will depend on sink emitters. + // also NB: times() checks cumulative calls, hence addAndGet + .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(1))).evalsha(any(), + any(), any(), any())) + .then(() -> assertNotNull(page4Sink.get(), "page 4 should have been fetched")) + .thenRequest(page + halfPage) // page 3 requested + .expectNextCount(page + halfPage) // page 1.5–3 produced + + .thenRequest(halfPage) // page 3.5 requested + .then(() -> assertNull(page56Sink.get(), "page 5 should not have been fetched yet")) + .then(() -> page4Sink.get().next(pages.pop()).complete()) + .expectNextCount(halfPage) // page 3.5 produced + .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(1))).evalsha(any(), + any(), any(), any())) + .then(() -> assertNotNull(page56Sink.get(), "page 5 should have been fetched")) + + .thenRequest(page) // page 4.5 requested + .expectNextCount(halfPage) // page 4 produced + + .thenRequest(page * 4) // request more demand than we will ultimately satisfy + + .then(() -> page56Sink.get().next(pages.pop()).next(pages.pop()).complete()) + .expectNextCount(page + page) // page 5 and 6 produced + .then(() -> emptyFinalPageSink.get().complete()) + // confirm that cache calls increased by 2: one for page 5-and-6 (we got a two-fer in next(pop()).next(pop()), + // and one for the final, empty page + .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(2))).evalsha(any(), + any(), any(), + any())) + .expectComplete() + .log() + .verify(); + + // make sure that we consumed all the pages, especially in case of refactoring + assertTrue(pages.isEmpty()); + } + + @Test + void testGetDiscardsEphemeralMessages() { + final Deque> pages = new ArrayDeque<>(); + pages.add(generatePage()); + pages.add(generatePage()); + pages.add(generateStaleEphemeralPage()); + + when(reactiveCommands.evalsha(any(), any(), any(), any())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.just(pages.pop())) + .thenReturn(Flux.empty()); + + final AsyncCommand removeSuccess = new AsyncCommand<>(mock(RedisCommand.class)); + removeSuccess.complete(); + + when(asyncCommands.evalsha(any(), any(), any(), any())) + .thenReturn((RedisFuture) removeSuccess); + + final Publisher allMessages = messagesCache.get(UUID.randomUUID(), 1L); + + StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); + + // async commands are used for remove(), and nothing should happen until we are subscribed + verify(asyncCommands, never()).evalsha(any(), any(), any(byte[][].class), any(byte[].class)); + // the reactive commands will be called once, to prep the first page fetch (but no remote request would actually be sent) + verify(reactiveCommands, times(1)).evalsha(any(), any(), any(byte[][].class), any(byte[].class)); + + StepVerifier.create(allMessages) + .expectSubscription() + .expectNextCount(200) + .expectComplete() + .log() + .verify(); + + assertTrue(pages.isEmpty()); + verify(asyncCommands, atLeast(1)).evalsha(any(), any(), any(), any()); + } + + private List generatePage() { + final List messagesAndIds = new ArrayList<>(); + + for (int i = 0; i < 100; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID(), true); + messagesAndIds.add(envelope.toByteArray()); + messagesAndIds.add(String.valueOf(serialTimestamp).getBytes()); + } + + return messagesAndIds; + } + + private List generateStaleEphemeralPage() { + final List messagesAndIds = new ArrayList<>(); + + for (int i = 0; i < 100; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID(), true) + .toBuilder().setEphemeral(true).build(); + messagesAndIds.add(envelope.toByteArray()); + messagesAndIds.add(String.valueOf(serialTimestamp).getBytes()); + } + + return messagesAndIds; + } } private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { @@ -234,194 +753,4 @@ class MessagesCacheTest { return envelopeBuilder.build(); } - - @Test - void testClearNullUuid() { - // We're happy as long as this doesn't throw an exception - messagesCache.clear(null); - } - - @Test - void testGetAccountFromQueueName() { - assertEquals(DESTINATION_UUID, - MessagesCache.getAccountUuidFromQueueName( - new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), - StandardCharsets.UTF_8))); - } - - @Test - void testGetDeviceIdFromQueueName() { - assertEquals(DESTINATION_DEVICE_ID, - MessagesCache.getDeviceIdFromQueueName( - new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), - StandardCharsets.UTF_8))); - } - - @Test - void testGetQueueNameFromKeyspaceChannel() { - assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7", - MessagesCache.getQueueNameFromKeyspaceChannel( - "__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}")); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - public void testGetQueuesToPersist(final boolean sealedSender) { - final UUID messageGuid = UUID.randomUUID(); - - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, sealedSender)); - final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID); - - assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty()); - - final List queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100); - - assertEquals(1, queues.size()); - assertEquals(DESTINATION_UUID, MessagesCache.getAccountUuidFromQueueName(queues.get(0))); - assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0))); - } - - @Test - void testNotifyListenerNewMessage() { - final AtomicBoolean notified = new AtomicBoolean(false); - final UUID messageGuid = UUID.randomUUID(); - - final MessageAvailabilityListener listener = new MessageAvailabilityListener() { - @Override - public boolean handleNewMessagesAvailable() { - synchronized (notified) { - notified.set(true); - notified.notifyAll(); - - return true; - } - } - - @Override - public boolean handleMessagesPersisted() { - return true; - } - }; - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, true)); - - synchronized (notified) { - while (!notified.get()) { - notified.wait(); - } - } - - assertTrue(notified.get()); - }); - } - - @Test - void testNotifyListenerPersisted() { - final AtomicBoolean notified = new AtomicBoolean(false); - - final MessageAvailabilityListener listener = new MessageAvailabilityListener() { - @Override - public boolean handleNewMessagesAvailable() { - return true; - } - - @Override - public boolean handleMessagesPersisted() { - synchronized (notified) { - notified.set(true); - notified.notifyAll(); - - return true; - } - } - }; - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); - - messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); - messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); - - synchronized (notified) { - while (!notified.get()) { - notified.wait(); - } - } - - assertTrue(notified.get()); - }); - } - - - /** - * Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning - * {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been called. - *

- * It uses a {@link CompletableFuture} to signal that it has received a “messages available” callback for the first - * time. - */ - private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener { - - private int counter; - - private final Consumer messageHandledCallback; - private final CompletableFuture firstMessageHandled = new CompletableFuture<>(); - - private NewMessagesAvailabilityClosedListener(final Consumer messageHandledCallback) { - this.messageHandledCallback = messageHandledCallback; - } - - @Override - public boolean handleNewMessagesAvailable() { - counter++; - messageHandledCallback.accept(counter); - firstMessageHandled.complete(null); - - return false; - - } - - @Override - public boolean handleMessagesPersisted() { - return true; - } - } - - @Test - void testAvailabilityListenerResponses() { - final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener( - count -> assertEquals(1, count)); - final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener( - count -> assertEquals(1, count)); - - assertTimeoutPreemptively(Duration.ofSeconds(30), () -> { - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1); - final UUID messageGuid1 = UUID.randomUUID(); - messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid1, true)); - - listener1.firstMessageHandled.get(); - - // Avoid a race condition by blocking on the message handled future *and* the current notification executor task— - // the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes - // `listener2` will get subscribed before `listener1` is cleaned up - notificationExecutorService.submit(() -> listener1.firstMessageHandled.get()).get(); - - final UUID messageGuid2 = UUID.randomUUID(); - messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid2, true)); - - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2); - - final UUID messageGuid3 = UUID.randomUUID(); - messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid3, true)); - - listener2.firstMessageHandled.get(); - }); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java index a98126059..eae1802ad 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java @@ -9,14 +9,26 @@ import static org.assertj.core.api.Assertions.assertThat; import com.google.protobuf.ByteString; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.tests.util.MessageHelper; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; class MessagesDynamoDbTest { @@ -59,6 +71,7 @@ class MessagesDynamoDbTest { MESSAGE3 = builder.build(); } + private ExecutorService messageDeletionExecutorService; private MessagesDynamoDb messagesDynamoDb; @@ -67,8 +80,18 @@ class MessagesDynamoDbTest { @BeforeEach void setup() { - messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME, - Duration.ofDays(14)); + messageDeletionExecutorService = Executors.newSingleThreadExecutor(); + messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), + dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14), + messageDeletionExecutorService); + } + + @AfterEach + void teardown() throws Exception { + messageDeletionExecutorService.shutdown(); + messageDeletionExecutorService.awaitTermination(5, TimeUnit.SECONDS); + + StepVerifier.resetDefaultTimeout(); } @Test @@ -77,7 +100,7 @@ class MessagesDynamoDbTest { final int destinationDeviceId = random.nextInt(255) + 1; messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); - final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, + final List messagesStored = load(destinationUuid, destinationDeviceId, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); assertThat(messagesStored).isNotNull().hasSize(3); final MessageProtos.Envelope firstMessage = @@ -88,6 +111,73 @@ class MessagesDynamoDbTest { assertThat(messagesStored).element(2).isEqualTo(MESSAGE2); } + @ParameterizedTest + @ValueSource(ints = {10, 100, 100, 1_000, 3_000}) + void testLoadManyAfterInsert(final int messageCount) { + final UUID destinationUuid = UUID.randomUUID(); + final int destinationDeviceId = random.nextInt(255) + 1; + + final List messages = new ArrayList<>(messageCount); + for (int i = 0; i < messageCount; i++) { + messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i)); + } + + messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + + final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, null); + + final long firstRequest = Math.min(10, messageCount); + StepVerifier.setDefaultTimeout(Duration.ofSeconds(15)); + + StepVerifier.Step step = StepVerifier.create(fetchedMessages, 0) + .expectSubscription() + .thenRequest(firstRequest) + .expectNextCount(firstRequest); + + if (messageCount > firstRequest) { + step = step.thenRequest(messageCount) + .expectNextCount(messageCount - firstRequest); + } + + step.thenCancel() + .verify(); + } + + @Test + void testLimitedLoad() { + final int messageCount = 200; + final UUID destinationUuid = UUID.randomUUID(); + final int destinationDeviceId = random.nextInt(255) + 1; + + final List messages = new ArrayList<>(messageCount); + for (int i = 0; i < messageCount; i++) { + messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i)); + } + + messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + + final int messageLoadLimit = 100; + final int halfOfMessageLoadLimit = messageLoadLimit / 2; + final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit); + + StepVerifier.setDefaultTimeout(Duration.ofSeconds(10)); + + final AtomicInteger messagesRemaining = new AtomicInteger(messageLoadLimit); + + StepVerifier.create(fetchedMessages, 0) + .expectSubscription() + .thenRequest(halfOfMessageLoadLimit) + .expectNextCount(halfOfMessageLoadLimit) + // the first 100 should be fetched and buffered, but further requests should fail + .then(() -> dynamoDbExtension.stopServer()) + .thenRequest(halfOfMessageLoadLimit) + .expectNextCount(halfOfMessageLoadLimit) + // we’ve consumed all the buffered messages, so a single request will fail + .thenRequest(1) + .expectError() + .verify(); + } + @Test void testDeleteForDestination() { final UUID destinationUuid = UUID.randomUUID(); @@ -96,18 +186,18 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @@ -119,71 +209,79 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @Test - void testDeleteMessageByDestinationAndGuid() { + void testDeleteMessageByDestinationAndGuid() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid, - UUID.fromString(MESSAGE2.getServerGuid())); + UUID.fromString(MESSAGE2.getServerGuid())).get(5, TimeUnit.SECONDS); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } @Test - void testDeleteSingleMessage() { + void testDeleteSingleMessage() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteMessage(secondDestinationUuid, 1, - UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()); + UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } + + private List load(final UUID destinationUuid, final long destinationDeviceId, + final int count) { + return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count)) + .take(count, true) + .collectList() + .block(); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index 55d35fb84..109390fb3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -14,13 +14,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import java.util.UUID; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.push.PushLatencyManager; class MessagesManagerTest { private final MessagesDynamoDb messagesDynamoDb = mock(MessagesDynamoDb.class); private final MessagesCache messagesCache = mock(MessagesCache.class); - private final PushLatencyManager pushLatencyManager = mock(PushLatencyManager.class); private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java index e3bc5ed44..894f12de0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java @@ -41,7 +41,7 @@ public class ProfilesManagerTest { void setUp() { //noinspection unchecked commands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); + final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.builder().stringCommands(commands).build(); profiles = mock(Profiles.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java new file mode 100644 index 000000000..0ff6e7856 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java @@ -0,0 +1,28 @@ +/* + * Copyright 2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import com.google.protobuf.ByteString; +import java.nio.charset.StandardCharsets; +import java.util.UUID; +import org.whispersystems.textsecuregcm.entities.MessageProtos; + +public class MessageHelper { + + public static MessageProtos.Envelope createMessage(UUID senderUuid, final int senderDeviceId, UUID destinationUuid, + long timestamp, String content) { + return MessageProtos.Envelope.newBuilder() + .setServerGuid(UUID.randomUUID().toString()) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setTimestamp(timestamp) + .setServerTimestamp(0) + .setSourceUuid(senderUuid.toString()) + .setSourceDevice(senderDeviceId) + .setDestinationUuid(destinationUuid.toString()) + .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) + .build(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java index 101a2575f..3112b2d48 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java @@ -5,70 +5,118 @@ package org.whispersystems.textsecuregcm.tests.util; -import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; -import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; - -import java.util.function.Consumer; -import java.util.function.Function; - import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; +import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import java.util.function.Consumer; +import java.util.function.Function; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; + public class RedisClusterHelper { - @SuppressWarnings("unchecked") - public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands stringCommands) { - return buildMockRedisCluster(stringCommands, mock(RedisAdvancedClusterCommands.class)); + public static RedisClusterHelper.Builder builder() { + return new Builder(); + } + + @SuppressWarnings("unchecked") + private static FaultTolerantRedisCluster buildMockRedisCluster( + final RedisAdvancedClusterCommands stringCommands, + final RedisAdvancedClusterCommands binaryCommands, + final RedisAdvancedClusterAsyncCommands binaryAsyncCommands, + final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) { + final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class); + final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class); + final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class); + + when(stringConnection.sync()).thenReturn(stringCommands); + when(binaryConnection.sync()).thenReturn(binaryCommands); + when(binaryConnection.async()).thenReturn(binaryAsyncCommands); + when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands); + + when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(stringConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(stringConnection); + return null; + }).when(cluster).useCluster(any(Consumer.class)); + + when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(stringConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(stringConnection); + return null; + }).when(cluster).useCluster(any(Consumer.class)); + + when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(binaryConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(binaryConnection); + return null; + }).when(cluster).useBinaryCluster(any(Consumer.class)); + + when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(binaryConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(binaryConnection); + return null; + }).when(cluster).useBinaryCluster(any(Consumer.class)); + + return cluster; + } + + @SuppressWarnings("unchecked") + public static class Builder { + + private RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class); + private RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class); + private RedisAdvancedClusterAsyncCommands binaryAsyncCommands = mock( + RedisAdvancedClusterAsyncCommands.class); + private RedisAdvancedClusterReactiveCommands binaryReactiveCommands = mock( + RedisAdvancedClusterReactiveCommands.class); + + private Builder() { + } - @SuppressWarnings("unchecked") - public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands stringCommands, final RedisAdvancedClusterCommands binaryCommands) { - final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class); - final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class); - final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class); - - when(stringConnection.sync()).thenReturn(stringCommands); - when(binaryConnection.sync()).thenReturn(binaryCommands); - - when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> { - return invocation.getArgument(0, Function.class).apply(stringConnection); - }); - - doAnswer(invocation -> { - invocation.getArgument(0, Consumer.class).accept(stringConnection); - return null; - }).when(cluster).useCluster(any(Consumer.class)); - - when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> { - return invocation.getArgument(0, Function.class).apply(stringConnection); - }); - - doAnswer(invocation -> { - invocation.getArgument(0, Consumer.class).accept(stringConnection); - return null; - }).when(cluster).useCluster(any(Consumer.class)); - - when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> { - return invocation.getArgument(0, Function.class).apply(binaryConnection); - }); - - doAnswer(invocation -> { - invocation.getArgument(0, Consumer.class).accept(binaryConnection); - return null; - }).when(cluster).useBinaryCluster(any(Consumer.class)); - - when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> { - return invocation.getArgument(0, Function.class).apply(binaryConnection); - }); - - doAnswer(invocation -> { - invocation.getArgument(0, Consumer.class).accept(binaryConnection); - return null; - }).when(cluster).useBinaryCluster(any(Consumer.class)); - - return cluster; + public Builder stringCommands(final RedisAdvancedClusterCommands stringCommands) { + this.stringCommands = stringCommands; + return this; } + + public Builder binaryCommands(final RedisAdvancedClusterCommands binaryCommands) { + this.binaryCommands = binaryCommands; + return this; + } + + public Builder binaryAsyncCommands(final RedisAdvancedClusterAsyncCommands binaryAsyncCommands) { + this.binaryAsyncCommands = binaryAsyncCommands; + return this; + } + + public Builder binaryReactiveCommands( + final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) { + this.binaryReactiveCommands = binaryReactiveCommands; + return this; + } + + public FaultTolerantRedisCluster build() { + return RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands, binaryAsyncCommands, + binaryReactiveCommands); + } + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 59a85f38a..6a912dff2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; +import java.time.Clock; import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -36,8 +37,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -56,6 +59,7 @@ import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import reactor.core.scheduler.Schedulers; class WebSocketConnectionIntegrationTest { @@ -65,16 +69,13 @@ class WebSocketConnectionIntegrationTest { @RegisterExtension static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - private static final int SEND_FUTURES_TIMEOUT_MILLIS = 100; - - private ExecutorService executorService; + private ExecutorService sharedExecutorService; private MessagesDynamoDb messagesDynamoDb; private MessagesCache messagesCache; private ReportMessageManager reportMessageManager; private Account account; private Device device; private WebSocketClient webSocketClient; - private WebSocketConnection webSocketConnection; private ScheduledExecutorService retrySchedulingExecutor; private long serialTimestamp = System.currentTimeMillis(); @@ -82,11 +83,12 @@ class WebSocketConnectionIntegrationTest { @BeforeEach void setUp() throws Exception { - executorService = Executors.newSingleThreadExecutor(); + sharedExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), - REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService); - messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME, - Duration.ofDays(7)); + REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, sharedExecutorService); + messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), + dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(7), + sharedExecutorService); reportMessageManager = mock(ReportMessageManager.class); account = mock(Account.class); device = mock(Device.class); @@ -96,30 +98,36 @@ class WebSocketConnectionIntegrationTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); - - webSocketConnection = new WebSocketConnection( - mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager), - new AuthenticatedAccount(() -> new Pair<>(account, device)), - device, - webSocketClient, - SEND_FUTURES_TIMEOUT_MILLIS, - retrySchedulingExecutor); } @AfterEach void tearDown() throws Exception { - executorService.shutdown(); - executorService.awaitTermination(2, TimeUnit.SECONDS); + sharedExecutorService.shutdown(); + sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS); retrySchedulingExecutor.shutdown(); retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS); } - @Test - void testProcessStoredMessages() { - final int persistedMessageCount = 207; - final int cachedMessageCount = 173; + @ParameterizedTest + @CsvSource({ + "207, 173, true", + "207, 173, false", + "323, 0, true", + "323, 0, false", + "0, 221, true", + "0, 221, false", + }) + void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount, + final boolean useReactive) { + final WebSocketConnection webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager), + new AuthenticatedAccount(() -> new Pair<>(account, device)), + device, + webSocketClient, + retrySchedulingExecutor, + useReactive); final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); @@ -150,8 +158,8 @@ class WebSocketConnectionIntegrationTest { final AtomicBoolean queueCleared = new AtomicBoolean(false); when(successResponse.getStatus()).thenReturn(200); - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( - CompletableFuture.completedFuture(successResponse)); + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) + .thenReturn(CompletableFuture.completedFuture(successResponse)); when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( (Answer>) invocation -> { @@ -194,8 +202,18 @@ class WebSocketConnectionIntegrationTest { }); } - @Test - void testProcessStoredMessagesClientClosed() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessagesClientClosed(final boolean useReactive) { + final WebSocketConnection webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager), + new AuthenticatedAccount(() -> new Pair<>(account, device)), + device, + webSocketClient, + retrySchedulingExecutor, + useReactive); + final int persistedMessageCount = 207; final int cachedMessageCount = 173; @@ -250,8 +268,20 @@ class WebSocketConnectionIntegrationTest { }); } - @Test - void testProcessStoredMessagesSendFutureTimeout() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessagesSendFutureTimeout(final boolean useReactive) { + final WebSocketConnection webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager), + new AuthenticatedAccount(() -> new Pair<>(account, device)), + device, + webSocketClient, + 100, // use a very short timeout, so that this test completes quickly + retrySchedulingExecutor, + useReactive, + Schedulers.boundedElastic()); + final int persistedMessageCount = 207; final int cachedMessageCount = 173; @@ -346,4 +376,5 @@ class WebSocketConnectionIntegrationTest { .setDestinationUuid(UUID.randomUUID().toString()) .build(); } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 5f4faa293..393526e78 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; @@ -42,17 +43,20 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.commons.lang3.RandomStringUtils; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentMatchers; -import org.mockito.invocation.InvocationOnMock; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; -import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -65,6 +69,10 @@ import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; class WebSocketConnectionTest { @@ -83,7 +91,6 @@ class WebSocketConnectionTest { private AuthenticatedAccount auth; private UpgradeRequest upgradeRequest; private ReceiptSender receiptSender; - private PushNotificationManager pushNotificationManager; private ScheduledExecutorService retrySchedulingExecutor; @BeforeEach @@ -95,17 +102,21 @@ class WebSocketConnectionTest { auth = new AuthenticatedAccount(() -> new Pair<>(account, device)); upgradeRequest = mock(UpgradeRequest.class); receiptSender = mock(ReceiptSender.class); - pushNotificationManager = mock(PushNotificationManager.class); retrySchedulingExecutor = mock(ScheduledExecutorService.class); } + @AfterEach + void teardown() { + StepVerifier.resetDefaultTimeout(); + } + @Test void testCredentials() { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(PushNotificationManager.class), mock(ClientPresenceManager.class), - retrySchedulingExecutor); + retrySchedulingExecutor, mock(ExperimentEnrollmentManager.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) @@ -114,7 +125,6 @@ class WebSocketConnectionTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.empty()); - when(upgradeRequest.getParameterMap()).thenReturn(Map.of( "login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD))); @@ -136,8 +146,9 @@ class WebSocketConnectionTest { assertTrue(account.isRequired()); } - @Test - void testOpen() throws Exception { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testOpen(final boolean useReactive) throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); UUID accountUuid = UUID.randomUUID(); @@ -166,29 +177,31 @@ class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenReturn(new Pair<>(outgoingMessages, false)); + if (useReactive) { + when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + .thenReturn(Flux.fromIterable(outgoingMessages)); + } else { + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) + .thenReturn(new Pair<>(outgoingMessages, false)); + } final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketClient client = mock(WebSocketClient.class); when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), nullable(List.class), any())) + .thenAnswer(invocation -> { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - auth, device, client, retrySchedulingExecutor); + auth, device, client, retrySchedulingExecutor, useReactive, Schedulers.immediate()); connection.start(); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), - ArgumentMatchers.>any()); + any()); assertEquals(3, futures.size()); @@ -208,12 +221,13 @@ class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } - @Test - public void testOnlineSend() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testOnlineSend(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); final UUID accountUuid = UUID.randomUUID(); @@ -222,24 +236,36 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(Collections.emptyList(), false)) - .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false)) - .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(Flux.empty()) + .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) + .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second"))) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(new Pair<>(Collections.emptyList(), false)) + .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), + false)) + .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), + false)) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); final AtomicInteger sendCounter = new AtomicInteger(0); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - synchronized (sendCounter) { - sendCounter.incrementAndGet(); - sendCounter.notifyAll(); - } + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + synchronized (sendCounter) { + sendCounter.incrementAndGet(); + sendCounter.notifyAll(); + } - return CompletableFuture.completedFuture(successResponse); - }); + return CompletableFuture.completedFuture(successResponse); + }); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { // This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of @@ -269,9 +295,10 @@ class WebSocketConnectionTest { verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); } - @Test - void testPendingSend() throws Exception { - MessagesManager storedMessages = mock(MessagesManager.class); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testPendingSend(final boolean useReactive) throws Exception { + MessagesManager storedMessages = mock(MessagesManager.class); final UUID accountUuid = UUID.randomUUID(); final UUID senderTwoUuid = UUID.randomUUID(); @@ -311,15 +338,20 @@ class WebSocketConnectionTest { when(sender1.getDevices()).thenReturn(sender1devices); when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); - when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); + when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenReturn(new Pair<>(pendingMessages, false)); + if (useReactive) { + when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + .thenReturn(Flux.fromIterable(pendingMessages)); + } else { + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) + .thenReturn(new Pair<>(pendingMessages, false)); + } final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketClient client = mock(WebSocketClient.class); when(client.getUserAgent()).thenReturn(userAgent); when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) @@ -330,7 +362,7 @@ class WebSocketConnectionTest { }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - auth, device, client, retrySchedulingExecutor); + auth, device, client, retrySchedulingExecutor, useReactive, Schedulers.immediate()); connection.start(); @@ -350,12 +382,13 @@ class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } - @Test - void testProcessStoredMessageConcurrency() throws InterruptedException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessageConcurrency(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -365,26 +398,45 @@ class WebSocketConnectionTest { final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false); - when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)).thenAnswer( - (Answer) invocation -> { - synchronized (threadWaiting) { - threadWaiting.set(true); - threadWaiting.notifyAll(); - } + if (useReactive) { + when( + messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false)) + .thenAnswer(invocation -> { + synchronized (threadWaiting) { + threadWaiting.set(true); + threadWaiting.notifyAll(); + } - synchronized (returnMessageList) { - while (!returnMessageList.get()) { - returnMessageList.wait(); - } - } + synchronized (returnMessageList) { + while (!returnMessageList.get()) { + returnMessageList.wait(); + } + } - return new OutgoingMessageEntityList(Collections.emptyList(), false); - }); + return Flux.empty(); + }); + } else { + when( + messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)) + .thenAnswer(invocation -> { + synchronized (threadWaiting) { + threadWaiting.set(true); + threadWaiting.notifyAll(); + } - final Thread[] threads = new Thread[10]; + synchronized (returnMessageList) { + while (!returnMessageList.get()) { + returnMessageList.wait(); + } + } + + return new Pair<>(Collections.emptyList(), false); + }); + } + + final Thread[] threads = new Thread[10]; final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { for (int i = 0; i < threads.length; i++) { threads[i] = new Thread(() -> { @@ -413,18 +465,24 @@ class WebSocketConnectionTest { } }); - verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), eq(false)); + if (useReactive) { + verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyLong(), eq(false)); + } else { + verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), eq(false)); + } } - @Test - void testProcessStoredMessagesMultiplePages() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessagesMultiplePages(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + final UUID accountUuid = UUID.randomUUID(); + when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); @@ -435,39 +493,56 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)) - .thenReturn(new Pair<>(firstPageMessages, true)) - .thenReturn(new Pair<>(secondPageMessages, false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), eq(false))) + .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream()))); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq(false))) + .thenReturn(new Pair<>(firstPageMessages, true)) + .thenReturn(new Pair<>(secondPageMessages, false)); + } + + when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - sendLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + return CompletableFuture.completedFuture(successResponse); + }); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { connection.processStoredMessages(); - sendLatch.await(); + queueEmptyLatch.await(); }); - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - @Test - void testProcessStoredMessagesContainsSenderUuid() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessagesContainsSenderUuid(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + final UUID accountUuid = UUID.randomUUID(); + when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); @@ -475,50 +550,65 @@ class WebSocketConnectionTest { final List messages = List.of( createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); - when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)) - .thenReturn(new Pair<>(messages, false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false)) + .thenReturn(Flux.fromIterable(messages)) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)) + .thenReturn(new Pair<>(messages, false)) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } + + when(messagesManager.delete(eq(accountUuid), eq(1L), any(UUID.class), any())) + .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - final CountDownLatch sendLatch = new CountDownLatch(messages.size()); + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer(invocation -> { - sendLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer( + invocation -> CompletableFuture.completedFuture(successResponse)); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { connection.processStoredMessages(); - - sendLatch.await(); + queueEmptyLatch.await(); }); - verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), argThat(argument -> { - if (argument.isEmpty()) { - return false; - } + verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), + argThat(argument -> { + if (argument.isEmpty()) { + return false; + } - final byte[] body = argument.get(); - try { - final Envelope envelope = Envelope.parseFrom(body); - if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) { - return false; - } - return envelope.getSourceUuid().equals(senderUuid.toString()); - } catch (InvalidProtocolBufferException e) { - return false; - } - })); + final byte[] body = argument.get(); + try { + final Envelope envelope = Envelope.parseFrom(body); + if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) { + return false; + } + return envelope.getSourceUuid().equals(senderUuid.toString()); + } catch (InvalidProtocolBufferException e) { + return false; + } + })); verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - @Test - void testProcessStoredMessagesSingleEmptyCall() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessStoredMessagesSingleEmptyCall(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); final UUID accountUuid = UUID.randomUUID(); @@ -527,8 +617,13 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(Collections.emptyList(), false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -543,12 +638,13 @@ class WebSocketConnectionTest { verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - @Test - public void testRequeryOnStateMismatch() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testRequeryOnStateMismatch(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); @@ -563,39 +659,57 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(firstPageMessages, false)) - .thenReturn(new Pair<>(secondPageMessages, false)) - .thenReturn(new Pair<>(Collections.emptyList(), false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(Flux.fromIterable(firstPageMessages)) + .thenReturn(Flux.fromIterable(secondPageMessages)) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(new Pair<>(firstPageMessages, false)) + .thenReturn(new Pair<>(secondPageMessages, false)) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } + + when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - connection.handleNewMessagesAvailable(); - sendLatch.countDown(); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + connection.handleNewMessagesAvailable(); - return CompletableFuture.completedFuture(successResponse); - }); + return CompletableFuture.completedFuture(successResponse); + }); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { connection.processStoredMessages(); - sendLatch.await(); + queueEmptyLatch.await(); }); - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - @Test - void testProcessCachedMessagesOnly() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessCachedMessagesOnly(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); final UUID accountUuid = UUID.randomUUID(); @@ -604,8 +718,13 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(Collections.emptyList(), false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -616,19 +735,28 @@ class WebSocketConnectionTest { // anything. connection.processStoredMessages(); - verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), false); + if (useReactive) { + verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false); + } else { + verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), false); + } connection.handleNewMessagesAvailable(); - verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), true); + if (useReactive) { + verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), true); + } else { + verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), true); + } } - @Test - void testProcessDatabaseMessagesAfterPersist() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testProcessDatabaseMessagesAfterPersist(final boolean useReactive) { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); final UUID accountUuid = UUID.randomUUID(); @@ -637,8 +765,13 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(Collections.emptyList(), false)); + if (useReactive) { + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(Flux.empty()); + } else { + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean())) + .thenReturn(new Pair<>(Collections.emptyList(), false)); + } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -650,151 +783,16 @@ class WebSocketConnectionTest { connection.processStoredMessages(); connection.handleMessagesPersisted(); - verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), false); + if (useReactive) { + verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false); + } else { + verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), false); + } } - @Test - void testDiscardOversizedMessagesForDesktop() { - MessagesManager storedMessages = mock(MessagesManager.class); - - UUID accountUuid = UUID.randomUUID(); - UUID senderOneUuid = UUID.randomUUID(); - UUID senderTwoUuid = UUID.randomUUID(); - - List outgoingMessages = List.of( - createMessage(senderOneUuid, UUID.randomUUID(), 1111, "first"), - createMessage(senderOneUuid, UUID.randomUUID(), 2222, - RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)), - createMessage(senderTwoUuid, UUID.randomUUID(), 3333, "third")); - - when(device.getId()).thenReturn(2L); - - when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); - - final Device sender1device = mock(Device.class); - - List sender1devices = List.of(sender1device); - - Account sender1 = mock(Account.class); - when(sender1.getDevices()).thenReturn(sender1devices); - - when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); - when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - - String userAgent = "Signal-Desktop/1.2.3"; - - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenReturn(new Pair<>(outgoingMessages, false)); - - final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); - - when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), - ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } - }); - - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, - retrySchedulingExecutor); - - connection.start(); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), - ArgumentMatchers.>any()); - - assertEquals(2, futures.size()); - - WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); - when(response.getStatus()).thenReturn(200); - futures.get(0).complete(response); - futures.get(1).complete(response); - - // We should delete all three messages even though we only sent two; one got discarded because it was too big for - // desktop clients. - verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class), any(Long.class)); - - connection.stop(); - verify(client).close(anyInt(), anyString()); - } - - @Test - void testSendOversizedMessagesForNonDesktop() { - MessagesManager storedMessages = mock(MessagesManager.class); - - UUID accountUuid = UUID.randomUUID(); - UUID senderOneUuid = UUID.randomUUID(); - UUID senderTwoUuid = UUID.randomUUID(); - - List outgoingMessages = List.of(createMessage(senderOneUuid, UUID.randomUUID(), 1111, "first"), - createMessage(senderOneUuid, UUID.randomUUID(), 2222, - RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)), - createMessage(senderTwoUuid, UUID.randomUUID(), 3333, "third")); - - when(device.getId()).thenReturn(2L); - - when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); - - final Device sender1device = mock(Device.class); - - List sender1devices = List.of(sender1device); - - Account sender1 = mock(Account.class); - when(sender1.getDevices()).thenReturn(sender1devices); - - when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); - when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - - String userAgent = "Signal-Android/4.68.3"; - - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenReturn(new Pair<>(outgoingMessages, false)); - - final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); - - when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), - ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } - }); - - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, - retrySchedulingExecutor); - - connection.start(); - verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), - ArgumentMatchers.>any()); - - assertEquals(3, futures.size()); - - WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); - when(response.getStatus()).thenReturn(200); - futures.get(0).complete(response); - futures.get(1).complete(response); - futures.get(2).complete(response); - - verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class), any(Long.class)); - - connection.stop(); - verify(client).close(anyInt(), anyString()); - } - - @Test - void testRetrieveMessageException() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRetrieveMessageException(final boolean useReactive) { MessagesManager storedMessages = mock(MessagesManager.class); UUID accountUuid = UUID.randomUUID(); @@ -804,10 +802,13 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - String userAgent = "Signal-Android/4.68.3"; - - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenThrow(new RedisException("OH NO")); + if (useReactive) { + when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + .thenReturn(Flux.error(new RedisException("OH NO"))); + } else { + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) + .thenThrow(new RedisException("OH NO")); + } when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer( (Answer>) invocation -> { @@ -819,7 +820,7 @@ class WebSocketConnectionTest { when(client.isOpen()).thenReturn(true); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); connection.start(); verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), @@ -827,8 +828,9 @@ class WebSocketConnectionTest { verify(client).close(eq(1011), anyString()); } - @Test - void testRetrieveMessageExceptionClientDisconnected() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRetrieveMessageExceptionClientDisconnected(final boolean useReactive) { MessagesManager storedMessages = mock(MessagesManager.class); UUID accountUuid = UUID.randomUUID(); @@ -838,22 +840,143 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - String userAgent = "Signal-Android/4.68.3"; - - when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) - .thenThrow(new RedisException("OH NO")); + if (useReactive) { + when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + .thenReturn(Flux.error(new RedisException("OH NO"))); + } else { + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false)) + .thenThrow(new RedisException("OH NO")); + } final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(false); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, - retrySchedulingExecutor); + retrySchedulingExecutor, useReactive, Schedulers.immediate()); connection.start(); verify(retrySchedulingExecutor, never()).schedule(any(Runnable.class), anyLong(), any()); verify(client, never()).close(anyInt(), anyString()); } + @Test + @Disabled("This test is flaky") + void testReactivePublisherLimitRate() { + MessagesManager storedMessages = mock(MessagesManager.class); + + final UUID accountUuid = UUID.randomUUID(); + + final long deviceId = 2L; + when(device.getId()).thenReturn(deviceId); + + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); + + final int totalMessages = 10; + final AtomicReference> sink = new AtomicReference<>(); + + final AtomicLong maxRequest = new AtomicLong(-1); + final Flux flux = Flux.create(s -> { + sink.set(s); + s.onRequest(n -> { + if (maxRequest.get() < n) { + maxRequest.set(n); + } + }); + }); + + when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + .thenReturn(flux); + + final WebSocketClient client = mock(WebSocketClient.class); + when(client.isOpen()).thenReturn(true); + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + retrySchedulingExecutor, true); + + connection.start(); + + StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); + + StepVerifier.create(flux, 0) + .expectSubscription() + .thenRequest(totalMessages * 2) + .then(() -> { + for (long i = 0; i < totalMessages; i++) { + sink.get().next(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i)); + } + sink.get().complete(); + }) + .expectNextCount(totalMessages) + .expectComplete() + .log() + .verify(); + + assertEquals(WebSocketConnection.MESSAGE_PUBLISHER_LIMIT_RATE, maxRequest.get()); + } + + @Test + void testReactivePublisherDisposedWhenConnectionStopped() { + MessagesManager storedMessages = mock(MessagesManager.class); + + final UUID accountUuid = UUID.randomUUID(); + + final long deviceId = 2L; + when(device.getId()).thenReturn(deviceId); + + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); + + final AtomicBoolean canceled = new AtomicBoolean(); + + final Flux flux = Flux.create(s -> { + s.onRequest(n -> { + // the subscriber should request more than 1 message, but we will only send one, so that + // we are sure the subscriber is waiting for more when we stop the connection + assert n > 1; + s.next(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")); + }); + + s.onCancel(() -> canceled.set(true)); + }); + when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + .thenReturn(flux); + + final WebSocketClient client = mock(WebSocketClient.class); + when(client.isOpen()).thenReturn(true); + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + retrySchedulingExecutor, true, Schedulers.immediate()); + + connection.start(); + + verify(client).sendRequest(any(), any(), any(), any()); + + // close the connection before the publisher completes + connection.stop(); + + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + + StepVerifier.create(flux) + .expectSubscription() + .expectNextCount(1) + .then(() -> assertTrue(canceled.get())) + // this is not entirely intuitive, but expecting a timeout is the recommendation for verifying cancellation + .expectTimeout(Duration.ofMillis(100)) + .log() + .verify(); + } + private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long timestamp, String content) { return Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) diff --git a/service/src/test/resources/logback-test.xml b/service/src/test/resources/logback-test.xml index 1c2c5d01c..b01f95f92 100644 --- a/service/src/test/resources/logback-test.xml +++ b/service/src/test/resources/logback-test.xml @@ -1,11 +1,14 @@ - - - %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n - - + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + - - - + + + + + +