From 8f7bae54fe20db6d67be360f24e7c59386c5d367 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Wed, 15 Nov 2023 17:15:17 -0800 Subject: [PATCH] When persisting messages fails due to a full queue in DynamoDB, automatically unlink one device to free up room. Co-authored-by: Chris Eager <79161849+eager-signal@users.noreply.github.com> --- .../storage/MessagePersister.java | 111 +++++++++-- .../textsecuregcm/storage/MessagesCache.java | 19 ++ .../storage/MessagesManager.java | 4 + .../MessagePersisterServiceCommand.java | 7 +- .../MessagePersisterIntegrationTest.java | 5 +- .../storage/MessagePersisterTest.java | 175 +++++++++++++++++- .../textsecuregcm/util/MockUtils.java | 37 ++++ 7 files changed, 334 insertions(+), 24 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index d2dbc64fe..ac0d4bab9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -16,17 +16,30 @@ import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.lifecycle.Managed; import io.micrometer.core.instrument.Counter; +import reactor.core.publisher.Flux; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; import java.time.Duration; import java.time.Instant; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; @@ -35,6 +48,8 @@ public class MessagePersister implements Managed { private final MessagesCache messagesCache; private final MessagesManager messagesManager; private final AccountsManager accountsManager; + private final ClientPresenceManager clientPresenceManager; + private final KeysManager keysManager; private final Duration persistDelay; @@ -50,27 +65,35 @@ public class MessagePersister implements Managed { private final Counter oversizedQueueCounter = counter(name(MessagePersister.class, "persistQueueOversized")); private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueCount")); private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueSize")); + private final ExecutorService executor; static final int QUEUE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100; private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis(); + public static final Duration UNLINK_TIMEOUT = Duration.ofHours(1); private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3; private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class); public MessagePersister(final MessagesCache messagesCache, final MessagesManager messagesManager, - final AccountsManager accountsManager, + final AccountsManager accountsManager, final ClientPresenceManager clientPresenceManager, + final KeysManager keysManager, final DynamicConfigurationManager dynamicConfigurationManager, final Duration persistDelay, - final int dedicatedProcessWorkerThreadCount) { + final int dedicatedProcessWorkerThreadCount, + final ExecutorService executor + ) { this.messagesCache = messagesCache; this.messagesManager = messagesManager; this.accountsManager = accountsManager; + this.clientPresenceManager = clientPresenceManager; + this.keysManager = keysManager; this.persistDelay = persistDelay; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; this.dedicatedProcess = true; + this.executor = executor; for (int i = 0; i < workerThreads.length; i++) { workerThreads[i] = new Thread(() -> { @@ -139,12 +162,14 @@ public class MessagePersister implements Managed { final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue); final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queue); + final Optional maybeAccount = accountsManager.getByAccountIdentifier(accountUuid); + if (maybeAccount.isEmpty()) { + logger.error("No account record found for account {}", accountUuid); + continue; + } try { - persistQueue(accountUuid, deviceId); + persistQueue(maybeAccount.get(), deviceId); } catch (final Exception e) { - if (e instanceof ItemCollectionSizeLimitExceededException) { - oversizedQueueCounter.increment(); - } persistQueueExceptionMeter.mark(); logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e); @@ -161,14 +186,8 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void persistQueue(final UUID accountUuid, final byte deviceId) throws MessagePersistenceException { - final Optional maybeAccount = accountsManager.getByAccountIdentifier(accountUuid); - - if (maybeAccount.isEmpty()) { - logger.error("No account record found for account {}", accountUuid); - return; - } - + void persistQueue(final Account account, final byte deviceId) throws MessagePersistenceException { + final UUID accountUuid = account.getUuid(); try (final Timer.Context ignored = persistQueueTimer.time()) { messagesCache.lockQueueForPersistence(accountUuid, deviceId); @@ -197,9 +216,73 @@ public class MessagePersister implements Managed { } while (!messages.isEmpty()); queueSizeHistogram.update(messageCount); + } catch (ItemCollectionSizeLimitExceededException e) { + oversizedQueueCounter.increment(); + unlinkLeastActiveDevice(account, deviceId); // this will either do a deferred reschedule for retry or throw } finally { messagesCache.unlockQueueForPersistence(accountUuid, deviceId); } } } + + @VisibleForTesting + void unlinkLeastActiveDevice(final Account account, byte destinationDeviceId) throws MessagePersistenceException { + if (!messagesCache.lockAccountForMessagePersisterCleanup(account.getUuid())) { + // don't try to unlink an account multiple times in parallel; just fail this persist attempt + // and we'll try again, hopefully deletions in progress will have made room + throw new MessagePersistenceException("account has a full queue and another device-unlinking attempt is in progress"); + } + + // Selection logic: + + // The primary device is never unlinked + // The least-recently-seen inactive device gets unlinked + // If there are none, the device with the oldest queued message (not necessarily the + // least-recently-seen; a device could be connecting frequently but have some problem fetching + // its messages) is unlinked + final Device deviceToDelete = account.getDevices() + .stream() + .filter(d -> !d.isPrimary() && !d.isEnabled()) + .min(Comparator.comparing(Device::getLastSeen)) + .or(() -> + Flux.fromIterable(account.getDevices()) + .filter(d -> !d.isPrimary()) + .flatMap(d -> + messagesManager + .getEarliestUndeliveredTimestampForDevice(account.getUuid(), d.getId()) + .map(t -> Tuples.of(d, t))) + .sort(Comparator.comparing(Tuple2::getT2)) + .map(Tuple2::getT1) + .next() + .blockOptional()) + .orElseThrow(() -> new MessagePersistenceException("account has a full queue and no unlinkable devices")); + + logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device {}{}", + account.getUuid(), destinationDeviceId, deviceToDelete.getId(), deviceToDelete.getId() == destinationDeviceId ? "" : " and schedule for retry"); + executor.execute( + () -> { + try { + // Lock the device's auth token first to prevent it from connecting while we're + // destroying its queue, but we don't want to completely remove the device from the + // account until we're sure the messages have been cleared because otherwise we won't + // be able to find it later to try again, in the event of a failure or timeout + final Account updatedAccount = accountsManager.update( + account, + a -> a.getDevice(deviceToDelete.getId()).ifPresent(Device::lockAuthTokenHash)); + clientPresenceManager.disconnectPresence(account.getUuid(), deviceToDelete.getId()); + CompletableFuture + .allOf( + keysManager.delete(account.getUuid(), deviceToDelete.getId()), + messagesManager.clear(account.getUuid(), deviceToDelete.getId())) + .orTimeout((UNLINK_TIMEOUT.toSeconds() * 3) / 4, TimeUnit.SECONDS) + .join(); + accountsManager.update(updatedAccount, a -> a.removeDevice(deviceToDelete.getId())); + } finally { + messagesCache.unlockAccountForMessagePersisterCleanup(account.getUuid()); + if (deviceToDelete.getId() != destinationDeviceId) { // no point in persisting a device we just purged + messagesCache.addQueueToPersist(account.getUuid(), destinationDeviceId); + } + } + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 9386dfdbe..e335ae6ce 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -12,6 +12,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.ScoredValue; import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.SetArgs; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; @@ -382,6 +383,20 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId))); } + boolean lockAccountForMessagePersisterCleanup(final UUID accountUuid) { + return readDeleteCluster.withBinaryCluster( + connection -> "OK".equals( + connection.sync().set( + getUnlinkInProgressKey(accountUuid), + LOCK_VALUE, + new SetArgs().ex(MessagePersister.UNLINK_TIMEOUT.toSeconds()).nx()))); + } + + void unlockAccountForMessagePersisterCleanup(final UUID accountUuid) { + readDeleteCluster.useBinaryCluster( + connection -> connection.sync().del(getUnlinkInProgressKey(accountUuid))); + } + public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId, final MessageAvailabilityListener listener) { final String queueName = getQueueName(destinationUuid, deviceId); @@ -531,6 +546,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } + private static byte[] getUnlinkInProgressKey(final UUID accountUuid) { + return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8); + } + static UUID getAccountUuidFromQueueName(final String queueName) { final int startOfHashTag = queueName.indexOf('{'); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index e0f0476c3..b2d926aa9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -102,6 +102,10 @@ public class MessagesManager { .tap(Micrometer.metrics(Metrics.globalRegistry)); } + public Mono getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, byte destinationDevice) { + return Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1)).map(Envelope::getServerTimestamp); + } + public CompletableFuture clear(UUID destinationUuid) { return CompletableFuture.allOf( messagesCache.clear(destinationUuid), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java index 8b0ca1cab..e56d2705e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java @@ -62,9 +62,14 @@ public class MessagePersisterServiceCommand extends ServerCommand { final UUID destinationUuid = invocation.getArgument(0); @@ -172,6 +188,7 @@ class MessagePersisterTest { final Account account = mock(Account.class); when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account)); + when(account.getUuid()).thenReturn(accountUuid); when(account.getNumber()).thenReturn(accountNumber); insertMessages(accountUuid, deviceId, messagesPerQueue, now); @@ -223,7 +240,150 @@ class MessagePersisterTest { assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertThrows(MessagePersistenceException.class, - () -> messagePersister.persistQueue(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID))); + () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID))); + } + + @Test + void testUnlinkFirstInactiveDeviceOnFullQueue() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = 1; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + final Device primary = mock(Device.class); + when(primary.getId()).thenReturn((byte) 1); + when(primary.isPrimary()).thenReturn(true); + when(primary.isEnabled()).thenReturn(true); + final Device activeA = mock(Device.class); + when(activeA.getId()).thenReturn((byte) 2); + when(activeA.isEnabled()).thenReturn(true); + final Device inactiveB = mock(Device.class); + final byte inactiveId = 3; + when(inactiveB.getId()).thenReturn(inactiveId); + when(inactiveB.isEnabled()).thenReturn(false); + final Device inactiveC = mock(Device.class); + when(inactiveC.getId()).thenReturn((byte) 4); + when(inactiveC.isEnabled()).thenReturn(false); + final Device activeD = mock(Device.class); + when(activeD.getId()).thenReturn((byte) 5); + when(activeD.isEnabled()).thenReturn(true); + final Device destination = mock(Device.class); + when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); + when(destination.isEnabled()).thenReturn(true); + + when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); + + when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.delete(any(), eq(inactiveId))).thenReturn(CompletableFuture.completedFuture(null)); + + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + + verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, inactiveId); + } + + @Test + void testUnlinkActiveDeviceWithOldestMessageOnFullQueueWithNoInactiveDevices() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = 1; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + final Device primary = mock(Device.class); + final byte primaryId = 1; + when(primary.getId()).thenReturn(primaryId); + when(primary.isPrimary()).thenReturn(true); + when(primary.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId))) + .thenReturn(Mono.just(4L)); + + final Device deviceA = mock(Device.class); + final byte deviceIdA = 2; + when(deviceA.getId()).thenReturn(deviceIdA); + when(deviceA.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA))) + .thenReturn(Mono.empty()); + + final Device deviceB = mock(Device.class); + final byte deviceIdB = 3; + when(deviceB.getId()).thenReturn(deviceIdB); + when(deviceB.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB))) + .thenReturn(Mono.just(2L)); + + final Device destination = mock(Device.class); + when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); + when(destination.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID))) + .thenReturn(Mono.just(5L)); + + when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); + + when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.delete(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null)); + + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + + verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, deviceIdB); + } + + @Test + void testUnlinkDestinationDevice() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = 1; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + final Device primary = mock(Device.class); + final byte primaryId = 1; + when(primary.getId()).thenReturn(primaryId); + when(primary.isPrimary()).thenReturn(true); + when(primary.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId))) + .thenReturn(Mono.just(1L)); + + final Device deviceA = mock(Device.class); + final byte deviceIdA = 2; + when(deviceA.getId()).thenReturn(deviceIdA); + when(deviceA.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA))) + .thenReturn(Mono.just(3L)); + + final Device deviceB = mock(Device.class); + final byte deviceIdB = 2; + when(deviceB.getId()).thenReturn(deviceIdB); + when(deviceB.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB))) + .thenReturn(Mono.empty()); + + final Device destination = mock(Device.class); + when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); + when(destination.isEnabled()).thenReturn(true); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID))) + .thenReturn(Mono.just(2L)); + + when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); + + when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + + assertTimeoutPreemptively(Duration.ofSeconds(1), () -> + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); + + verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID); } @SuppressWarnings("SameParameterValue") @@ -265,5 +425,4 @@ class MessagePersisterTest { REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster( connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1))); } - } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index af5a62b3b..bca1ad942 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -9,12 +9,23 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; +import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted; +import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked; +import static org.mockito.internal.invocation.InvocationMarker.markVerified; +import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified; +import static org.mockito.internal.invocation.InvocationsFinder.findInvocations; import java.time.Duration; +import java.util.List; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.function.Predicate; + import org.apache.commons.lang3.RandomUtils; import org.mockito.Mockito; +import org.mockito.invocation.Invocation; +import org.mockito.invocation.MatchableInvocation; +import org.mockito.verification.VerificationMode; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -154,4 +165,30 @@ public final class MockUtils { } return new SecretBytes(bytes); } + + /** + * modeled after {@link org.mockito.Mockito#only()}, verifies that the matched invocation is the only invocation of + * this method + */ + public static VerificationMode exactly() { + return data -> { + MatchableInvocation target = data.getTarget(); + final List allInvocations = data.getAllInvocations(); + List chunk = findInvocations(allInvocations, target); + List otherInvocations = allInvocations.stream() + .filter(target::hasSameMethod) + .filter(Predicate.not(target::matches)) + .toList(); + + if (!otherInvocations.isEmpty()) { + Invocation unverified = findFirstUnverified(otherInvocations); + throw noMoreInteractionsWanted(unverified, (List) allInvocations); + } + if (chunk.isEmpty()) { + throw wantedButNotInvoked(target); + } + markVerified(chunk.get(0), target); + }; + } + }