Fetch destination accounts in parallel when sending multi-recipient messages

This commit is contained in:
Jon Chambers 2023-12-01 15:42:07 -05:00 committed by Jon Chambers
parent 2ab3c97ee8
commit e5d3be16b0
2 changed files with 22 additions and 5 deletions

View File

@ -147,6 +147,9 @@ public class MessageController {
private final ReportSpamTokenProvider reportSpamTokenProvider;
private final ClientReleaseManager clientReleaseManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage");
private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
@ -368,7 +371,8 @@ public class MessageController {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
.flatMap(
t -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(t.getT1()))
t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1()))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.map(
account ->
@ -379,7 +383,8 @@ public class MessageController {
t.getT2().getDevicesAndRegistrationIds().collect(
Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second))))
// IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e))
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e),
MAX_FETCH_ACCOUNT_CONCURRENCY)
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
.block();
}

View File

@ -162,7 +162,7 @@ class MessageControllerTest {
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private static final UUID INTERNATIONAL_UUID = UUID.fromString("44444444-4444-4444-4444-444444444444");
@SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
@ -223,6 +223,13 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
mock(DynamicInboundMessageByteLimitConfiguration.class);
@ -1019,8 +1026,13 @@ class MessageControllerTest {
final UUID pni = new UUID(1L, (long) i);
final String e164 = String.format("+1408555%04d", i);
final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci))).thenReturn(Optional.of(account));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni))).thenReturn(Optional.of(account));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(aci)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(pni)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48])));
}