From e5d3be16b091cb88e4b75089826437137cc1e5d9 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 1 Dec 2023 15:42:07 -0500 Subject: [PATCH] Fetch destination accounts in parallel when sending multi-recipient messages --- .../controllers/MessageController.java | 9 +++++++-- .../controllers/MessageControllerTest.java | 18 +++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 6302a51eb..5bb9ca184 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -147,6 +147,9 @@ public class MessageController { private final ReportSpamTokenProvider reportSpamTokenProvider; private final ClientReleaseManager clientReleaseManager; private final DynamicConfigurationManager dynamicConfigurationManager; + + private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8; + private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); @@ -368,7 +371,8 @@ public class MessageController { return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) .map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue())) .flatMap( - t -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(t.getT1())) + t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1())) + .flatMap(Mono::justOrEmpty) .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) .map( account -> @@ -379,7 +383,8 @@ public class MessageController { t.getT2().getDevicesAndRegistrationIds().collect( Collectors.toMap(Pair::first, Pair::second)))) // IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device - .onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e)) + .onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e), + MAX_FETCH_ACCOUNT_CONCURRENCY) .collectMap(MultiRecipientDeliveryData::serviceIdentifier) .block(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 06fca3edb..6c6965ed7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -162,7 +162,7 @@ class MessageControllerTest { private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; - private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); + private static final UUID INTERNATIONAL_UUID = UUID.fromString("44444444-4444-4444-4444-444444444444"); @SuppressWarnings("unchecked") private static final RedisAdvancedClusterCommands redisCommands = mock(RedisAdvancedClusterCommands.class); @@ -223,6 +223,13 @@ class MessageControllerTest { when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount))); + final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = mock(DynamicInboundMessageByteLimitConfiguration.class); @@ -1019,8 +1026,13 @@ class MessageControllerTest { final UUID pni = new UUID(1L, (long) i); final String e164 = String.format("+1408555%04d", i); final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci))).thenReturn(Optional.of(account)); - when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni))).thenReturn(Optional.of(account)); + + when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(aci))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(pni))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48]))); }