From 65b2892de58926b0246e31c8266a3bbc993ba1f7 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:03:43 -0700 Subject: [PATCH] Simplify unlink-device-on-full-DB process --- .../storage/MessagePersister.java | 74 ++---------- .../textsecuregcm/storage/MessagesCache.java | 14 --- .../MessagePersisterServiceCommand.java | 5 +- .../MessagePersisterIntegrationTest.java | 3 +- .../storage/MessagePersisterTest.java | 105 ++++-------------- 5 files changed, 32 insertions(+), 169 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index a909ccbbe..230457690 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -62,13 +61,11 @@ public class MessagePersister implements Managed { .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) .distributionStatisticExpiry(Duration.ofMinutes(10)) .register(Metrics.globalRegistry); - private final ExecutorService executor; static final int QUEUE_BATCH_LIMIT = 100; static final int MESSAGE_BATCH_LIMIT = 100; private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis(); - public static final Duration UNLINK_TIMEOUT = Duration.ofHours(1); private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3; @@ -79,8 +76,7 @@ public class MessagePersister implements Managed { final KeysManager keysManager, final DynamicConfigurationManager dynamicConfigurationManager, final Duration persistDelay, - final int dedicatedProcessWorkerThreadCount, - final ExecutorService executor + final int dedicatedProcessWorkerThreadCount ) { this.messagesCache = messagesCache; this.messagesManager = messagesManager; @@ -89,7 +85,6 @@ public class MessagePersister implements Managed { this.keysManager = keysManager; this.persistDelay = persistDelay; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; - this.executor = executor; for (int i = 0; i < workerThreads.length; i++) { workerThreads[i] = new Thread(() -> { @@ -221,7 +216,7 @@ public class MessagePersister implements Managed { queueSizeDistributionSummery.record(messageCount); } catch (ItemCollectionSizeLimitExceededException e) { oversizedQueueCounter.increment(); - unlinkLeastActiveDevice(account, deviceId); // this will either do a deferred reschedule for retry or throw + maybeUnlink(account, deviceId); // may throw, in which case we'll retry later by the usual mechanism } finally { messagesCache.unlockQueueForPersistence(accountUuid, deviceId); sample.stop(persistQueueTimer); @@ -230,67 +225,12 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void unlinkLeastActiveDevice(final Account account, byte destinationDeviceId) throws MessagePersistenceException { - if (!messagesCache.lockAccountForMessagePersisterCleanup(account.getUuid())) { - // don't try to unlink an account multiple times in parallel; just fail this persist attempt - // and we'll try again, hopefully deletions in progress will have made room - throw new MessagePersistenceException("account has a full queue and another device-unlinking attempt is in progress"); + void maybeUnlink(final Account account, byte destinationDeviceId) throws MessagePersistenceException { + if (destinationDeviceId == Device.PRIMARY_ID) { + throw new MessagePersistenceException("primary device has a full queue"); } - // Selection logic: - - // The primary device is never unlinked - // The least-recently-seen inactive device gets unlinked - // If there are none, the device with the oldest queued message (not necessarily the - // least-recently-seen; a device could be connecting frequently but have some problem fetching - // its messages) is unlinked - final Device deviceToDelete = account.getDevices() - .stream() - .filter(d -> !d.isPrimary() && !deviceHasMessageDeliveryChannel(d)) - .min(Comparator.comparing(Device::getLastSeen)) - .or(() -> - Flux.fromIterable(account.getDevices()) - .filter(d -> !d.isPrimary()) - .flatMap(d -> - messagesManager - .getEarliestUndeliveredTimestampForDevice(account.getUuid(), d) - .map(t -> Tuples.of(d, t))) - .sort(Comparator.comparing(Tuple2::getT2)) - .map(Tuple2::getT1) - .next() - .blockOptional()) - .orElseThrow(() -> new MessagePersistenceException("account has a full queue and no unlinkable devices")); - - logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device {}{}", - account.getUuid(), destinationDeviceId, deviceToDelete.getId(), deviceToDelete.getId() == destinationDeviceId ? "" : " and schedule for retry"); - executor.execute( - () -> { - try { - // Lock the device's auth token first to prevent it from connecting while we're - // destroying its queue, but we don't want to completely remove the device from the - // account until we're sure the messages have been cleared because otherwise we won't - // be able to find it later to try again, in the event of a failure or timeout - final Account updatedAccount = accountsManager.update( - account, - a -> a.getDevice(deviceToDelete.getId()).ifPresent(Device::lockAuthTokenHash)); - clientPresenceManager.disconnectPresence(account.getUuid(), deviceToDelete.getId()); - CompletableFuture - .allOf( - keysManager.deleteSingleUsePreKeys(account.getUuid(), deviceToDelete.getId()), - messagesManager.clear(account.getUuid(), deviceToDelete.getId())) - .orTimeout((UNLINK_TIMEOUT.toSeconds() * 3) / 4, TimeUnit.SECONDS) - .join(); - accountsManager.removeDevice(updatedAccount, deviceToDelete.getId()).join(); - } finally { - messagesCache.unlockAccountForMessagePersisterCleanup(account.getUuid()); - if (deviceToDelete.getId() != destinationDeviceId) { // no point in persisting a device we just purged - messagesCache.addQueueToPersist(account.getUuid(), destinationDeviceId); - } - } - }); - } - - private static boolean deviceHasMessageDeliveryChannel(final Device device) { - return device.getFetchesMessages() || StringUtils.isNotEmpty(device.getApnId()) || StringUtils.isNotEmpty(device.getGcmId()); + logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", account.getUuid(), destinationDeviceId); + accountsManager.removeDevice(account, destinationDeviceId).join(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 645220b1d..80b7ec758 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -395,20 +395,6 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId))); } - boolean lockAccountForMessagePersisterCleanup(final UUID accountUuid) { - return redisCluster.withBinaryCluster( - connection -> "OK".equals( - connection.sync().set( - getUnlinkInProgressKey(accountUuid), - LOCK_VALUE, - new SetArgs().ex(MessagePersister.UNLINK_TIMEOUT.toSeconds()).nx()))); - } - - void unlockAccountForMessagePersisterCleanup(final UUID accountUuid) { - redisCluster.useBinaryCluster( - connection -> connection.sync().del(getUnlinkInProgressKey(accountUuid))); - } - public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId, final MessageAvailabilityListener listener) { final String queueName = getQueueName(destinationUuid, deviceId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java index ae885b8fe..1a89e4447 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java @@ -66,10 +66,7 @@ public class MessagePersisterServiceCommand extends ServerCommand messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); - - verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, inactiveId); + verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); } @Test - void testUnlinkActiveDeviceWithOldestMessageOnFullQueueWithNoInactiveDevices() { + void testFailedUnlinkOnFullQueueThrowsForRetry() { final String queueName = new String( MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); final int messageCount = 1; @@ -302,93 +302,34 @@ class MessagePersisterTest { setNextSlotToPersist(SlotHash.getSlot(queueName)); final Device primary = mock(Device.class); - final byte primaryId = 1; - when(primary.getId()).thenReturn(primaryId); + when(primary.getId()).thenReturn((byte) 1); when(primary.isPrimary()).thenReturn(true); when(primary.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primary))) - .thenReturn(Mono.just(4L)); - final Device deviceA = mock(Device.class); - final byte deviceIdA = 2; - when(deviceA.getId()).thenReturn(deviceIdA); - when(deviceA.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceA))) - .thenReturn(Mono.empty()); + final Device activeA = mock(Device.class); + when(activeA.getId()).thenReturn((byte) 2); + when(activeA.getFetchesMessages()).thenReturn(true); - final Device deviceB = mock(Device.class); - final byte deviceIdB = 3; - when(deviceB.getId()).thenReturn(deviceIdB); - when(deviceB.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceB))) - .thenReturn(Mono.just(2L)); + final Device inactiveB = mock(Device.class); + final byte inactiveId = 3; + when(inactiveB.getId()).thenReturn(inactiveId); + + final Device inactiveC = mock(Device.class); + when(inactiveC.getId()).thenReturn((byte) 4); + + final Device activeD = mock(Device.class); + when(activeD.getId()).thenReturn((byte) 5); + when(activeD.getFetchesMessages()).thenReturn(true); final Device destination = mock(Device.class); when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); - when(destination.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(destination))) - .thenReturn(Mono.just(5L)); - when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); + when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); - when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.deleteSingleUsePreKeys(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); - assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); - - verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, deviceIdB); - } - - @Test - void testUnlinkDestinationDevice() { - final String queueName = new String( - MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); - final int messageCount = 1; - final Instant now = Instant.now(); - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); - setNextSlotToPersist(SlotHash.getSlot(queueName)); - - final Device primary = mock(Device.class); - final byte primaryId = 1; - when(primary.getId()).thenReturn(primaryId); - when(primary.isPrimary()).thenReturn(true); - when(primary.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primary))) - .thenReturn(Mono.just(1L)); - - final Device deviceA = mock(Device.class); - final byte deviceIdA = 2; - when(deviceA.getId()).thenReturn(deviceIdA); - when(deviceA.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceA))) - .thenReturn(Mono.just(3L)); - - final Device deviceB = mock(Device.class); - final byte deviceIdB = 2; - when(deviceB.getId()).thenReturn(deviceIdB); - when(deviceB.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceB))) - .thenReturn(Mono.empty()); - - final Device destination = mock(Device.class); - when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); - when(destination.getFetchesMessages()).thenReturn(true); - when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(destination))) - .thenReturn(Mono.just(2L)); - - when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination)); - - when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); - when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - - assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); - - verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID); + assertThrows(CompletionException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE)); } @SuppressWarnings("SameParameterValue")