From 01743e5c8868ca75eedf5472852d214550c9d195 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Tue, 4 Jun 2024 12:16:43 -0700 Subject: [PATCH] Delete messages lazily on account and device deletion to prevent timeouts when deleting accounts/devices with large queues --- .../textsecuregcm/WhisperServerService.java | 1 + .../dynamic/DynamicConfiguration.java | 8 + .../dynamic/DynamicMessagesConfiguration.java | 25 ++ .../controllers/MessageController.java | 4 +- .../storage/MessagePersister.java | 14 +- .../storage/MessagesDynamoDb.java | 152 +++++++++--- .../storage/MessagesManager.java | 24 +- .../websocket/WebSocketConnection.java | 6 +- .../workers/CommandDependencies.java | 1 + .../controllers/MessageControllerTest.java | 14 +- .../MessagePersisterIntegrationTest.java | 5 +- .../storage/MessagePersisterTest.java | 53 +++-- .../storage/MessagesDynamoDbTest.java | 216 +++++++++++++----- .../WebSocketConnectionIntegrationTest.java | 12 +- .../websocket/WebSocketConnectionTest.java | 56 ++--- 15 files changed, 415 insertions(+), 176 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index b2610222c..e4a835b83 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -394,6 +394,7 @@ public class WhisperServerService extends Application getExperimentEnrollmentConfiguration( final String experimentName) { return Optional.ofNullable(experiments.get(experimentName)); @@ -121,4 +125,8 @@ public class DynamicConfiguration { return metricsConfiguration; } + public DynamicMessagesConfiguration getMessagesConfiguration() { + return messagesConfiguration; + } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java new file mode 100644 index 000000000..726e9a27c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessagesConfiguration.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration.dynamic; + +import java.util.List; + +import javax.validation.constraints.NotNull; + +public record DynamicMessagesConfiguration(@NotNull List dynamoKeySchemes) { + public enum DynamoKeyScheme { + TRADITIONAL, + LAZY_DELETION; + } + + public DynamicMessagesConfiguration() { + this(List.of(DynamoKeyScheme.TRADITIONAL)); + } + + public DynamoKeyScheme writeKeyScheme() { + return dynamoKeySchemes().getLast(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index d6674a359..276ffabef 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -724,7 +724,7 @@ public class MessageController { return messagesManager.getMessagesForDevice( auth.getAccount().getUuid(), - auth.getAuthenticatedDevice().getId(), + auth.getAuthenticatedDevice(), false) .map(messagesAndHasMore -> { Stream envelopes = messagesAndHasMore.first().stream(); @@ -768,7 +768,7 @@ public class MessageController { public CompletableFuture removePendingMessage(@ReadOnly @Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) { return messagesManager.delete( auth.getAccount().getUuid(), - auth.getAuthenticatedDevice().getId(), + auth.getAuthenticatedDevice(), uuid, null) .thenAccept(maybeDeletedMessage -> { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 4d68b1fd3..d189f352d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -161,8 +161,13 @@ public class MessagePersister implements Managed { logger.error("No account record found for account {}", accountUuid); continue; } + final Optional maybeDevice = maybeAccount.flatMap(account -> account.getDevice(deviceId)); + if (maybeDevice.isEmpty()) { + logger.error("Account {} does not have a device with id {}", accountUuid, deviceId); + continue; + } try { - persistQueue(maybeAccount.get(), deviceId); + persistQueue(maybeAccount.get(), maybeDevice.get()); } catch (final Exception e) { persistQueueExceptionMeter.increment(); logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e); @@ -180,8 +185,9 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void persistQueue(final Account account, final byte deviceId) throws MessagePersistenceException { + void persistQueue(final Account account, final Device device) throws MessagePersistenceException { final UUID accountUuid = account.getUuid(); + final byte deviceId = device.getId(); final Timer.Sample sample = Timer.start(); @@ -196,7 +202,7 @@ public class MessagePersister implements Managed { do { messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); - int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, deviceId, messages); + int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages); messageCount += messages.size(); if (messagesRemovedFromCache == 0) { @@ -246,7 +252,7 @@ public class MessagePersister implements Managed { .filter(d -> !d.isPrimary()) .flatMap(d -> messagesManager - .getEarliestUndeliveredTimestampForDevice(account.getUuid(), d.getId()) + .getEarliestUndeliveredTimestampForDevice(account.getUuid(), d) .map(t -> Tuples.of(d, t))) .sort(Comparator.comparing(Tuple2::getT2)) .map(Tuple2::getT1) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index 9f2bd5bd1..8080a18b1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -11,6 +11,9 @@ import static io.micrometer.core.instrument.Metrics.timer; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.protobuf.InvalidProtocolBufferException; + +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; import java.nio.ByteBuffer; import java.time.Duration; @@ -26,12 +29,16 @@ import java.util.function.Predicate; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration.DynamoKeyScheme; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.util.AttributeValues; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import reactor.util.function.Tuples; +import reactor.util.function.Tuple2; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -62,46 +69,53 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { private final Timer storeTimer = timer(name(getClass(), "store")); private final String DELETE_BY_ACCOUNT_TIMER_NAME = name(getClass(), "delete", "account"); private final String DELETE_BY_DEVICE_TIMER_NAME = name(getClass(), "delete", "device"); + private final String MESSAGES_STORED_BY_SCHEME_COUNTER_NAME = name(getClass(), "messagesStored"); + private final String MESSAGES_LOADED_BY_SCHEME_COUNTER_NAME = name(getClass(), "messagesLoaded"); + private final String MESSAGES_DELETED_BY_SCHEME_COUNTER_NAME = name(getClass(), "messagesDeleted"); private final DynamoDbAsyncClient dbAsyncClient; private final String tableName; private final Duration timeToLive; + private final DynamicConfigurationManager dynamicConfig; private final ExecutorService messageDeletionExecutor; private final Scheduler messageDeletionScheduler; private static final Logger logger = LoggerFactory.getLogger(MessagesDynamoDb.class); public MessagesDynamoDb(DynamoDbClient dynamoDb, DynamoDbAsyncClient dynamoDbAsyncClient, String tableName, - Duration timeToLive, ExecutorService messageDeletionExecutor) { + Duration timeToLive, DynamicConfigurationManager dynamicConfig, ExecutorService messageDeletionExecutor) { super(dynamoDb); this.dbAsyncClient = dynamoDbAsyncClient; this.tableName = tableName; this.timeToLive = timeToLive; + this.dynamicConfig = dynamicConfig; this.messageDeletionExecutor = messageDeletionExecutor; this.messageDeletionScheduler = Schedulers.fromExecutor(messageDeletionExecutor); } public void store(final List messages, final UUID destinationAccountUuid, - final byte destinationDeviceId) { - storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId))); + final Device destinationDevice) { + storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDevice))); } private void storeBatch(final List messages, final UUID destinationAccountUuid, - final byte destinationDeviceId) { + final Device destinationDevice) { + final byte destinationDeviceId = destinationDevice.getId(); if (messages.size() > DYNAMO_DB_MAX_BATCH_SIZE) { throw new IllegalArgumentException("Maximum batch size of " + DYNAMO_DB_MAX_BATCH_SIZE + " exceeded with " + messages.size() + " messages"); } - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final DynamoKeyScheme scheme = dynamicConfig.getConfiguration().getMessagesConfiguration().writeKeyScheme(); + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, destinationDevice, scheme); List writeItems = new ArrayList<>(); for (MessageProtos.Envelope message : messages) { final UUID messageUuid = UUID.fromString(message.getServerGuid()); final ImmutableMap.Builder item = ImmutableMap.builder() .put(KEY_PARTITION, partitionKey) - .put(KEY_SORT, convertSortKey(destinationDeviceId, message.getServerTimestamp(), messageUuid)) + .put(KEY_SORT, convertSortKey(destinationDevice.getId(), message.getServerTimestamp(), messageUuid, scheme)) .put(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT, convertLocalIndexMessageUuidSortKey(messageUuid)) .put(KEY_TTL, AttributeValues.fromLong(getTtlForMessage(message))) .put(KEY_ENVELOPE_BYTES, AttributeValue.builder().b(SdkBytes.fromByteArray(message.toByteArray())).build()); @@ -112,22 +126,43 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems)); + Metrics.counter(MESSAGES_STORED_BY_SCHEME_COUNTER_NAME, Tags.of("scheme", scheme.name())).increment(writeItems.size()); } - public Publisher load(final UUID destinationAccountUuid, final byte destinationDeviceId, - final Integer limit) { + public Publisher load(final UUID destinationAccountUuid, final Device device, final Integer limit) { + return Flux.concat( + dynamicConfig.getConfiguration().getMessagesConfiguration().dynamoKeySchemes() + .stream() + .map(scheme -> load(destinationAccountUuid, device, limit, scheme)) + .toList()) + .map(messageAndScheme -> { + Metrics.counter(MESSAGES_LOADED_BY_SCHEME_COUNTER_NAME, Tags.of("scheme", messageAndScheme.getT2().name())).increment(); + return messageAndScheme.getT1(); + }); + } - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); - final QueryRequest.Builder queryRequestBuilder = QueryRequest.builder() + private Publisher> load(final UUID destinationAccountUuid, final Device device, final Integer limit, final DynamoKeyScheme scheme) { + final byte destinationDeviceId = device.getId(); + + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, device, scheme); + QueryRequest.Builder queryRequestBuilder = QueryRequest.builder() .tableName(tableName) - .consistentRead(true) - .keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )") - .expressionAttributeNames(Map.of( - "#part", KEY_PARTITION, - "#sort", KEY_SORT)) - .expressionAttributeValues(Map.of( - ":part", partitionKey, - ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))); + .consistentRead(true); + + queryRequestBuilder = switch (scheme) { + case TRADITIONAL -> queryRequestBuilder + .keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )") + .expressionAttributeNames(Map.of( + "#part", KEY_PARTITION, + "#sort", KEY_SORT)) + .expressionAttributeValues(Map.of( + ":part", partitionKey, + ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId, scheme))); + case LAZY_DELETION -> queryRequestBuilder + .keyConditionExpression("#part = :part") + .expressionAttributeNames(Map.of("#part", KEY_PARTITION)) + .expressionAttributeValues(Map.of(":part", partitionKey)); + }; if (limit != null) { // some callers don’t take advantage of reactive streams, so we want to support limiting the fetch size. Otherwise, @@ -146,13 +181,25 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { return null; } }) - .filter(Predicate.not(Objects::isNull)); + .filter(Predicate.not(Objects::isNull)) + .map(m -> Tuples.of(m, scheme)); } public CompletableFuture> deleteMessageByDestinationAndGuid( - final UUID destinationAccountUuid, final UUID messageUuid) { + final UUID destinationAccountUuid, final Device destinationDevice, final UUID messageUuid) { + return dynamicConfig.getConfiguration().getMessagesConfiguration().dynamoKeySchemes() + .stream() + .map(scheme -> deleteMessageByDestinationAndGuid(destinationAccountUuid, destinationDevice, messageUuid, scheme)) + // this combines the futures by producing a future that returns an arbitrary nonempty + // result if there is one, which should be OK because only one of the keying schemes + // should produce a nonempty result for any given message uuid + .reduce((f, g) -> f.thenCombine(g, (a, b) -> a.or(() -> b))) + .get(); // there is always at least one scheme + } - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + private CompletableFuture> deleteMessageByDestinationAndGuid( + final UUID destinationAccountUuid, final Device destinationDevice, final UUID messageUuid, DynamoKeyScheme scheme) { + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, destinationDevice, scheme); final QueryRequest queryRequest = QueryRequest.builder() .tableName(tableName) .indexName(LOCAL_INDEX_MESSAGE_UUID_NAME) @@ -179,6 +226,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .mapNotNull(deleteItemResponse -> { try { if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { + Metrics.counter(MESSAGES_DELETED_BY_SCHEME_COUNTER_NAME, Tags.of("scheme", scheme.name())).increment(); return convertItemToEnvelope(deleteItemResponse.attributes()); } } catch (final InvalidProtocolBufferException e) { @@ -193,10 +241,21 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } public CompletableFuture> deleteMessage(final UUID destinationAccountUuid, - final byte destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { + final Device destinationDevice, final UUID messageUuid, final long serverTimestamp) { + return dynamicConfig.getConfiguration().getMessagesConfiguration().dynamoKeySchemes() + .stream() + .map(scheme -> deleteMessage(destinationAccountUuid, destinationDevice, messageUuid, serverTimestamp, scheme)) + // this combines the futures by producing a future that returns an arbitrary nonempty + // result if there is one, which should be OK because only one of the keying schemes + // should produce a nonempty result for any given message uuid + .reduce((f, g) -> f.thenCombine(g, (a, b) -> a.or(() -> b))) + .orElseThrow(); // there is always at least one scheme + } - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); - final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid); + private CompletableFuture> deleteMessage(final UUID destinationAccountUuid, + final Device destinationDevice, final UUID messageUuid, final long serverTimestamp, final DynamoKeyScheme scheme) { + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, destinationDevice, scheme); + final AttributeValue sortKey = convertSortKey(destinationDevice.getId(), serverTimestamp, messageUuid, scheme); DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() .tableName(tableName) .key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, sortKey)) @@ -216,10 +275,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { }, messageDeletionExecutor); } + // Deletes all messages stored for the supplied account that were stored under the traditional (uuid+device id) keying scheme. + // Messages stored under the lazy-message-deletion keying scheme will not be affected. public CompletableFuture deleteAllMessagesForAccount(final UUID destinationAccountUuid) { final Timer.Sample sample = Timer.start(); - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, null, DynamoKeyScheme.TRADITIONAL); return Flux.from(dbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) @@ -243,10 +304,13 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .toFuture(); } + // Deletes all messages stored for the supplied account and device that were stored under the + // traditional (uuid+device id) keying scheme. Messages stored under the lazy-message-deletion + // keying scheme will not be affected. public CompletableFuture deleteAllMessagesForDevice(final UUID destinationAccountUuid, final byte destinationDeviceId) { final Timer.Sample sample = Timer.start(); - final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); + final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid, null, DynamoKeyScheme.TRADITIONAL); return Flux.from(dbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) @@ -256,7 +320,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { "#sort", KEY_SORT)) .expressionAttributeValues(Map.of( ":part", partitionKey, - ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) + ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId, DynamoKeyScheme.TRADITIONAL))) .projectionExpression(KEY_SORT) .consistentRead(true) .build()) @@ -285,26 +349,38 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); } - private static AttributeValue convertPartitionKey(final UUID destinationAccountUuid) { - return AttributeValues.fromUUID(destinationAccountUuid); + private static AttributeValue convertPartitionKey(final UUID destinationAccountUuid, final Device destinationDevice, final DynamoKeyScheme scheme) { + return switch (scheme) { + case TRADITIONAL -> AttributeValues.fromUUID(destinationAccountUuid); + case LAZY_DELETION -> { + final ByteBuffer byteBuffer = ByteBuffer.allocate(24); + byteBuffer.putLong(destinationAccountUuid.getMostSignificantBits()); + byteBuffer.putLong(destinationAccountUuid.getLeastSignificantBits()); + byteBuffer.putLong(destinationDevice.getCreated() & ~0x7f + destinationDevice.getId()); + yield AttributeValues.fromByteBuffer(byteBuffer.flip()); + } + }; } private static AttributeValue convertSortKey(final byte destinationDeviceId, final long serverTimestamp, - final UUID messageUuid) { - ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); - // for compatibility - destinationDeviceId was previously `long` - byteBuffer.putLong(destinationDeviceId); + final UUID messageUuid, final DynamoKeyScheme scheme) { + + final ByteBuffer byteBuffer = ByteBuffer.allocate(32); + if (scheme == DynamoKeyScheme.TRADITIONAL) { + // for compatibility - destinationDeviceId was previously `long` + byteBuffer.putLong(destinationDeviceId); + } byteBuffer.putLong(serverTimestamp); byteBuffer.putLong(messageUuid.getMostSignificantBits()); byteBuffer.putLong(messageUuid.getLeastSignificantBits()); return AttributeValues.fromByteBuffer(byteBuffer.flip()); } - private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final byte destinationDeviceId) { - ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); - // for compatibility - destinationDeviceId was previously `long` - byteBuffer.putLong(destinationDeviceId); - return AttributeValues.fromByteBuffer(byteBuffer.flip()); + private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final byte destinationDeviceId, final DynamoKeyScheme scheme) { + return switch (scheme) { + case TRADITIONAL -> AttributeValues.fromByteBuffer(ByteBuffer.allocate(8).putLong(destinationDeviceId).flip()); + case LAZY_DELETION -> AttributeValues.b(new byte[0]); + }; } private static AttributeValue convertLocalIndexMessageUuidSortKey(final UUID messageUuid) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 9e1f4955f..d07930211 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -67,7 +67,7 @@ public class MessagesManager { return messagesCache.hasMessages(destinationUuid, destinationDevice); } - public Mono, Boolean>> getMessagesForDevice(UUID destinationUuid, byte destinationDevice, + public Mono, Boolean>> getMessagesForDevice(UUID destinationUuid, Device destinationDevice, boolean cachedMessagesOnly) { return Flux.from( @@ -77,25 +77,25 @@ public class MessagesManager { .map(envelopes -> new Pair<>(envelopes, envelopes.size() >= RESULT_SET_CHUNK_SIZE)); } - public Publisher getMessagesForDeviceReactive(UUID destinationUuid, byte destinationDevice, + public Publisher getMessagesForDeviceReactive(UUID destinationUuid, Device destinationDevice, final boolean cachedMessagesOnly) { return getMessagesForDevice(destinationUuid, destinationDevice, null, cachedMessagesOnly); } - private Publisher getMessagesForDevice(UUID destinationUuid, byte destinationDevice, + private Publisher getMessagesForDevice(UUID destinationUuid, Device destinationDevice, @Nullable Integer limit, final boolean cachedMessagesOnly) { final Publisher dynamoPublisher = cachedMessagesOnly ? Flux.empty() : messagesDynamoDb.load(destinationUuid, destinationDevice, limit); - final Publisher cachePublisher = messagesCache.get(destinationUuid, destinationDevice); + final Publisher cachePublisher = messagesCache.get(destinationUuid, destinationDevice.getId()); return Flux.concat(dynamoPublisher, cachePublisher) .name(GET_MESSAGES_FOR_DEVICE_FLUX_NAME) .tap(Micrometer.metrics(Metrics.globalRegistry)); } - public Mono getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, byte destinationDevice) { + public Mono getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, Device destinationDevice) { return Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1)).map(Envelope::getServerTimestamp); } @@ -111,9 +111,9 @@ public class MessagesManager { messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId)); } - public CompletableFuture> delete(UUID destinationUuid, byte destinationDeviceId, UUID guid, + public CompletableFuture> delete(UUID destinationUuid, Device destinationDevice, UUID guid, @Nullable Long serverTimestamp) { - return messagesCache.remove(destinationUuid, destinationDeviceId, guid) + return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid) .thenComposeAsync(removed -> { if (removed.isPresent()) { @@ -121,9 +121,9 @@ public class MessagesManager { } if (serverTimestamp == null) { - return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, guid); + return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDevice, guid); } else { - return messagesDynamoDb.deleteMessage(destinationUuid, destinationDeviceId, guid, serverTimestamp); + return messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, serverTimestamp); } }, messageDeletionExecutor); @@ -134,20 +134,20 @@ public class MessagesManager { */ public int persistMessages( final UUID destinationUuid, - final byte destinationDeviceId, + final Device destinationDevice, final List messages) { final List nonEphemeralMessages = messages.stream() .filter(envelope -> !envelope.getEphemeral()) .collect(Collectors.toList()); - messagesDynamoDb.store(nonEphemeralMessages, destinationUuid, destinationDeviceId); + messagesDynamoDb.store(nonEphemeralMessages, destinationUuid, destinationDevice); final List messageGuids = messages.stream().map(message -> UUID.fromString(message.getServerGuid())) .collect(Collectors.toList()); int messagesRemovedFromCache = 0; try { - messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDeviceId, messageGuids) + messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDevice.getId(), messageGuids) .get(30, TimeUnit.SECONDS).size(); PERSIST_MESSAGE_COUNTER.increment(nonEphemeralMessages.size()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 44923458f..bf709623e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -228,7 +228,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final CompletableFuture result; if (isSuccessResponse(response)) { - result = messagesManager.delete(auth.getAccount().getUuid(), device.getId(), + result = messagesManager.delete(auth.getAccount().getUuid(), device, storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) .thenApply(ignored -> null); @@ -355,7 +355,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private void sendMessages(final boolean cachedMessagesOnly, final CompletableFuture queueCleared) { final Publisher messages = - messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly); + messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), device, cachedMessagesOnly); final AtomicBoolean hasErrored = new AtomicBoolean(); @@ -414,7 +414,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); if (envelope.getStory() && !client.shouldDeliverStories()) { - messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); + messagesManager.delete(auth.getAccount().getUuid(), device, messageGuid, envelope.getServerTimestamp()); return CompletableFuture.completedFuture(null); } else { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 07598850f..5c85f692a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -172,6 +172,7 @@ record CommandDependencies( MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getExpiration(), + dynamicConfigurationManager, messageDeletionExecutor); FaultTolerantRedisCluster messagesCluster = configuration.getMessageCacheConfiguration() .getRedisClusterConfiguration().build("messages", redisClientResourcesBuilder); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index e5f3e8969..735609d8f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -136,6 +136,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; @@ -649,7 +650,7 @@ class MessageControllerTest { AuthHelper.VALID_UUID, null, null, 0, true) ); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), anyBoolean())) .thenReturn(Mono.just(new Pair<>(envelopes, false))); final String userAgent = "Test-UA"; @@ -703,7 +704,7 @@ class MessageControllerTest { UUID.randomUUID(), (byte) 2, AuthHelper.VALID_UUID, null, null, 0) ); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), anyBoolean())) .thenReturn(Mono.just(new Pair<>(messages, false))); Response response = @@ -723,24 +724,25 @@ class MessageControllerTest { UUID sourceUuid = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid1, null)) + + when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null)) .thenReturn( CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE, timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)))); UUID uuid2 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid2, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null)) .thenReturn( CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope( uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0)))); UUID uuid3 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid3, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null)) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(Optional.empty())); UUID uuid4 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid4, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid4, null)) .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Oh No"))); Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 46de756d0..e714beeaf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -35,6 +35,8 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; + import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -74,7 +76,7 @@ class MessagePersisterIntegrationTest { messageDeletionExecutorService = Executors.newSingleThreadExecutor(); final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(14), - messageDeletionExecutorService); + dynamicConfigurationManager, messageDeletionExecutorService); final AccountsManager accountsManager = mock(AccountsManager.class); notificationExecutorService = Executors.newSingleThreadExecutor(); @@ -93,6 +95,7 @@ class MessagePersisterIntegrationTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account)); + when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(DevicesHelper.createDevice(Device.PRIMARY_ID))); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index b5977c6c5..4bfb7bbcc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -46,6 +46,8 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; + import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; @@ -71,6 +73,7 @@ class MessagePersisterTest { private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; private static final byte DESTINATION_DEVICE_ID = 7; + private static final Device DESTINATION_DEVICE = DevicesHelper.createDevice(DESTINATION_DEVICE_ID); private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); @@ -93,6 +96,7 @@ class MessagePersisterTest { when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID); when(destinationAccount.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); + when(destinationAccount.getDevice(DESTINATION_DEVICE_ID)).thenReturn(Optional.of(DESTINATION_DEVICE)); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); sharedExecutorService = Executors.newSingleThreadExecutor(); @@ -103,15 +107,15 @@ class MessagePersisterTest { messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager, keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1, MoreExecutors.newDirectExecutorService()); - when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> { + when(messagesManager.persistMessages(any(UUID.class), any(), any())).thenAnswer(invocation -> { final UUID destinationUuid = invocation.getArgument(0); - final byte destinationDeviceId = invocation.getArgument(1); + final Device destinationDevice = invocation.getArgument(1); final List messages = invocation.getArgument(2); - messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + messagesDynamoDb.store(messages, destinationUuid, destinationDevice); for (final MessageProtos.Envelope message : messages) { - messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get(); + messagesCache.remove(destinationUuid, destinationDevice.getId(), UUID.fromString(message.getServerGuid())).get(); } return messages.size(); @@ -150,7 +154,7 @@ class MessagePersisterTest { final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), - eq(DESTINATION_DEVICE_ID)); + eq(DESTINATION_DEVICE)); assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @@ -166,7 +170,7 @@ class MessagePersisterTest { messagePersister.persistNextQueues(now); - verify(messagesDynamoDb, never()).store(any(), any(), anyByte()); + verify(messagesDynamoDb, never()).store(any(), any(), any()); } @Test @@ -187,6 +191,7 @@ class MessagePersisterTest { when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account)); when(account.getUuid()).thenReturn(accountUuid); when(account.getNumber()).thenReturn(accountNumber); + when(account.getDevice(anyByte())).thenAnswer(invocation -> Optional.of(DevicesHelper.createDevice(invocation.getArgument(0)))); insertMessages(accountUuid, deviceId, messagesPerQueue, now); } @@ -197,7 +202,7 @@ class MessagePersisterTest { final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyByte()); + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), any()); assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @@ -213,7 +218,7 @@ class MessagePersisterTest { doAnswer((Answer) invocation -> { throw new RuntimeException("OH NO."); - }).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); + }).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE)); messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); @@ -233,11 +238,11 @@ class MessagePersisterTest { setNextSlotToPersist(SlotHash.getSlot(queueName)); // returning `0` indicates something not working correctly - when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenReturn(0); + when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenReturn(0); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertThrows(MessagePersistenceException.class, - () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID))); + () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE))); } @Test @@ -273,12 +278,12 @@ class MessagePersisterTest { when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); - when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.deleteSingleUsePreKeys(any(), eq(inactiveId))).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, inactiveId); } @@ -298,37 +303,37 @@ class MessagePersisterTest { when(primary.getId()).thenReturn(primaryId); when(primary.isPrimary()).thenReturn(true); when(primary.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primary))) .thenReturn(Mono.just(4L)); final Device deviceA = mock(Device.class); final byte deviceIdA = 2; when(deviceA.getId()).thenReturn(deviceIdA); when(deviceA.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceA))) .thenReturn(Mono.empty()); final Device deviceB = mock(Device.class); final byte deviceIdB = 3; when(deviceB.getId()).thenReturn(deviceIdB); when(deviceB.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceB))) .thenReturn(Mono.just(2L)); final Device destination = mock(Device.class); when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); when(destination.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(destination))) .thenReturn(Mono.just(5L)); when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); - when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.deleteSingleUsePreKeys(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, deviceIdB); } @@ -348,37 +353,37 @@ class MessagePersisterTest { when(primary.getId()).thenReturn(primaryId); when(primary.isPrimary()).thenReturn(true); when(primary.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primary))) .thenReturn(Mono.just(1L)); final Device deviceA = mock(Device.class); final byte deviceIdA = 2; when(deviceA.getId()).thenReturn(deviceIdA); when(deviceA.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceA))) .thenReturn(Mono.just(3L)); final Device deviceB = mock(Device.class); final byte deviceIdB = 2; when(deviceB.getId()).thenReturn(deviceIdB); when(deviceB.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceB))) .thenReturn(Mono.empty()); final Device destination = mock(Device.class); when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); when(destination.isEnabled()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID))) + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(destination))) .thenReturn(Mono.just(2L)); when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); - when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java index 29e731004..e22c57ec9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java @@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.storage; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.protobuf.ByteString; import java.time.Duration; @@ -25,9 +27,12 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.MessageHelper; +import org.whispersystems.textsecuregcm.util.SystemMapper; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; @@ -73,18 +78,20 @@ class MessagesDynamoDbTest { } private ExecutorService messageDeletionExecutorService; + private DynamicConfigurationManager dynamicConfigurationManager; private MessagesDynamoDb messagesDynamoDb; - @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.MESSAGES); @BeforeEach void setup() { messageDeletionExecutorService = Executors.newSingleThreadExecutor(); + dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(14), - messageDeletionExecutorService); + dynamicConfigurationManager, messageDeletionExecutorService); } @AfterEach @@ -99,9 +106,11 @@ class MessagesDynamoDbTest { void testSimpleFetchAfterInsert() { final UUID destinationUuid = UUID.randomUUID(); final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); - messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); + final Device destinationDevice = DevicesHelper.createDevice(destinationDeviceId); - final List messagesStored = load(destinationUuid, destinationDeviceId, + messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDevice); + + final List messagesStored = load(destinationUuid, destinationDevice, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); assertThat(messagesStored).isNotNull().hasSize(3); final MessageProtos.Envelope firstMessage = @@ -117,6 +126,7 @@ class MessagesDynamoDbTest { void testLoadManyAfterInsert(final int messageCount) { final UUID destinationUuid = UUID.randomUUID(); final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); + final Device destinationDevice = DevicesHelper.createDevice(destinationDeviceId); final List messages = new ArrayList<>(messageCount); for (int i = 0; i < messageCount; i++) { @@ -124,9 +134,9 @@ class MessagesDynamoDbTest { "message " + i)); } - messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + messagesDynamoDb.store(messages, destinationUuid, destinationDevice); - final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, null); + final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDevice, null); final long firstRequest = Math.min(10, messageCount); StepVerifier.setDefaultTimeout(Duration.ofSeconds(15)); @@ -150,6 +160,7 @@ class MessagesDynamoDbTest { final int messageCount = 200; final UUID destinationUuid = UUID.randomUUID(); final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); + final Device destinationDevice = DevicesHelper.createDevice(destinationDeviceId); final List messages = new ArrayList<>(messageCount); for (int i = 0; i < messageCount; i++) { @@ -157,11 +168,11 @@ class MessagesDynamoDbTest { "message " + i)); } - messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + messagesDynamoDb.store(messages, destinationUuid, destinationDevice); final int messageLoadLimit = 100; final int halfOfMessageLoadLimit = messageLoadLimit / 2; - final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, + final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDevice, messageLoadLimit); StepVerifier.setDefaultTimeout(Duration.ofSeconds(10)); @@ -186,23 +197,25 @@ class MessagesDynamoDbTest { void testDeleteForDestination() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); - final byte deviceId2 = 2; - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, deviceId2); + final Device primary = DevicesHelper.createDevice((byte) 1); + final Device device2 = DevicesHelper.createDevice((byte) 2); - assertThat(load(destinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, device2); + + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid).join(); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @@ -210,26 +223,28 @@ class MessagesDynamoDbTest { void testDeleteForDestinationDevice() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); - final byte destinationDeviceId2 = 2; - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); + final Device primary = DevicesHelper.createDevice((byte) 1); + final Device device2 = DevicesHelper.createDevice((byte) 2); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, device2); + + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); - messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, destinationDeviceId2).join(); + messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, device2.getId()).join(); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @@ -237,35 +252,37 @@ class MessagesDynamoDbTest { void testDeleteMessageByDestinationAndGuid() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); - final byte destinationDeviceId2 = 2; - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); + final Device primary = DevicesHelper.createDevice((byte) 1); + final Device device2 = DevicesHelper.createDevice((byte) 2); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, device2); + + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); final Optional deletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid( - secondDestinationUuid, + secondDestinationUuid, primary, UUID.fromString(MESSAGE2.getServerGuid())).get(5, TimeUnit.SECONDS); assertThat(deletedMessage).isPresent(); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); final Optional alreadyDeletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid( - secondDestinationUuid, + secondDestinationUuid, primary, UUID.fromString(MESSAGE2.getServerGuid())).get(5, TimeUnit.SECONDS); assertThat(alreadyDeletedMessage).isNotPresent(); @@ -276,36 +293,127 @@ class MessagesDynamoDbTest { void testDeleteSingleMessage() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); - final byte destinationDeviceId2 = 2; - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); + final Device primary = DevicesHelper.createDevice((byte) 1); + final Device device2 = DevicesHelper.createDevice((byte) 2); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, primary); + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, device2); + + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); - messagesDynamoDb.deleteMessage(secondDestinationUuid, Device.PRIMARY_ID, + messagesDynamoDb.deleteMessage(secondDestinationUuid, primary, UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS); - assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, device2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } - private List load(final UUID destinationUuid, final byte destinationDeviceId, + private List load(final UUID destinationUuid, final Device destinationDevice, final int count) { - return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count)) + return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDevice, count)) .take(count, true) .collectList() .block(); } + + @Test + void testMessageKeySchemeMigration() throws Exception { + final UUID destinationUuid = UUID.randomUUID(); + final Device primary = DevicesHelper.createDevice((byte) 1); + + // store message 1 in old scheme + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - TRADITIONAL + """, DynamicConfiguration.class)); + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, primary); + + // store message 2 in new scheme during migration + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - TRADITIONAL + - LAZY_DELETION + """, DynamicConfiguration.class)); + messagesDynamoDb.store(List.of(MESSAGE2), destinationUuid, primary); + + // in old scheme, we should only get message 1 back (we would never actually do this, it's just a way to prove we used the new scheme for message 2) + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - TRADITIONAL + """, DynamicConfiguration.class)); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).containsExactly(MESSAGE1); + + // during migration we should get both messages back in order + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - TRADITIONAL + - LAZY_DELETION + """, DynamicConfiguration.class)); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).containsExactly(MESSAGE1, MESSAGE2); + + // after migration we would only get message 2 back (we shouldn't do this either in practice) + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - LAZY_DELETION + """, DynamicConfiguration.class)); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).containsExactly(MESSAGE2); + } + + @Test + void testLazyMessageDeletion() throws Exception { + final UUID destinationUuid = UUID.randomUUID(); + final Device primary = DevicesHelper.createDevice((byte) 1); + primary.setCreated(System.currentTimeMillis()); + + when(dynamicConfigurationManager.getConfiguration()).thenReturn(SystemMapper.yamlMapper().readValue(""" + messagesConfiguration: + dynamoKeySchemes: + - LAZY_DELETION + """, DynamicConfiguration.class)); + + messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, primary); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("load should return all messages stored").containsOnly(MESSAGE1, MESSAGE2, MESSAGE3); + + messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, primary, UUID.fromString(MESSAGE1.getServerGuid())) + .get(1, TimeUnit.SECONDS); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("deleting message by guid should work").containsExactly(MESSAGE3, MESSAGE2); + + messagesDynamoDb.deleteMessage(destinationUuid, primary, UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()) + .get(1, TimeUnit.SECONDS); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("deleting message by guid and timestamp should work").containsExactly(MESSAGE3); + + messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, (byte) 1).get(1, TimeUnit.SECONDS); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("deleting all messages for device should do nothing").containsExactly(MESSAGE3); + + messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid).get(1, TimeUnit.SECONDS); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("deleting all messages for account should do nothing").containsExactly(MESSAGE3); + + primary.setCreated(primary.getCreated() + 1000); + assertThat(load(destinationUuid, primary, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .as("devices with the same id but different create timestamps should see no messages") + .isEmpty(); + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 532e03960..59cd33e60 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -44,6 +44,7 @@ import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; @@ -52,6 +53,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.MessagesCache; @@ -86,6 +88,8 @@ class WebSocketConnectionIntegrationTest { @BeforeEach void setUp() throws Exception { + final DynamicConfigurationManager mockDynamicConfigurationManager = mock(DynamicConfigurationManager.class); + when(mockDynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); sharedExecutorService = Executors.newSingleThreadExecutor(); scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); @@ -94,7 +98,7 @@ class WebSocketConnectionIntegrationTest { messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7), - sharedExecutorService); + mockDynamicConfigurationManager, sharedExecutorService); reportMessageManager = mock(ReportMessageManager.class); account = mock(Account.class); device = mock(Device.class); @@ -147,7 +151,7 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device); } for (int i = 0; i < cachedMessageCount; i++) { @@ -235,7 +239,7 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device); } for (int i = 0; i < cachedMessageCount; i++) { @@ -303,7 +307,7 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device); } for (int i = 0; i < cachedMessageCount; i++) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index ce7468507..4800d17be 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -184,12 +184,12 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); String userAgent = HttpHeaders.USER_AGENT; - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenReturn(Flux.fromIterable(outgoingMessages)); final List> futures = new LinkedList<>(); @@ -218,7 +218,7 @@ class WebSocketConnectionTest { futures.get(0).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException()); - verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId), + verify(messagesManager, times(1)).delete(eq(accountUuid), eq(device), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)), eq(2222L)); @@ -239,7 +239,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), argThat(d -> d.getId() == Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.empty()) .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second"))) @@ -330,12 +330,12 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); String userAgent = HttpHeaders.USER_AGENT; - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenReturn(Flux.fromIterable(pendingMessages)); final List> futures = new LinkedList<>(); @@ -383,7 +383,7 @@ class WebSocketConnectionTest { final AtomicBoolean returnMessageList = new AtomicBoolean(false); when( - messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false)) + messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenAnswer(invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); @@ -430,7 +430,7 @@ class WebSocketConnectionTest { } }); - verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyByte(), eq(false)); + verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), any(), eq(false)); } @Test @@ -451,10 +451,10 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), eq(false))) + when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream()))); - when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -506,10 +506,10 @@ class WebSocketConnectionTest { .toList(); final Flux allMessages = Flux.concat(firstPublisher, secondPublisher); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), eq(false))) + when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) .thenReturn(allMessages); - when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -579,11 +579,11 @@ class WebSocketConnectionTest { final List messages = List.of( createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenReturn(Flux.fromIterable(messages)) .thenReturn(Flux.empty()); - when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(UUID.class), any())) + when(messagesManager.delete(eq(accountUuid), eq(device), any(UUID.class), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -642,7 +642,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -676,12 +676,12 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(Flux.fromIterable(firstPageMessages)) .thenReturn(Flux.fromIterable(secondPageMessages)) .thenReturn(Flux.empty()); - when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -725,7 +725,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -737,11 +737,11 @@ class WebSocketConnectionTest { // anything. connection.processStoredMessages(); - verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false); + verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false); connection.handleNewMessagesAvailable(); - verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), true); + verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true); } @Test @@ -756,7 +756,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -769,7 +769,7 @@ class WebSocketConnectionTest { connection.processStoredMessages(); connection.handleMessagesPersisted(); - verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false); + verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false); } @Test @@ -781,7 +781,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenReturn(Flux.error(new RedisException("OH NO"))); when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer( @@ -810,7 +810,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) .thenReturn(Flux.error(new RedisException("OH NO"))); final WebSocketClient client = mock(WebSocketClient.class); @@ -838,7 +838,7 @@ class WebSocketConnectionTest { final TestPublisher testPublisher = TestPublisher.createCold(); final Flux flux = Flux.from(testPublisher); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); @@ -846,7 +846,7 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); WebSocketConnection connection = webSocketConnection(client); @@ -894,7 +894,7 @@ class WebSocketConnectionTest { s.onCancel(() -> canceled.set(true)); }); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); @@ -902,7 +902,7 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); WebSocketConnection connection = webSocketConnection(client);