diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index ff485db80..adb208c9e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -32,8 +32,9 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; @@ -466,25 +467,37 @@ public class MessageController { Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); - multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.recipients()) - .map(recipient -> (Callable) () -> { - Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid()); + CompletableFuture.allOf( + Arrays.stream(multiRecipientMessage.recipients()) + // If we're sending a story, some recipients might not map to existing accounts + .filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid())) + .map( + recipient -> CompletableFuture.runAsync( + () -> { + Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid()); - // we asserted this must exist in validateCompleteDeviceList - Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); - sentMessageCounter.increment(); - try { - sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, - recipient, multiRecipientMessage.commonPayload()); - } catch (NoSuchUserException e) { - uuids404.add(recipient.uuid()); - } - return null; - }) - .collect(Collectors.toList())); + // we asserted this must exist in validateCompleteDeviceList + Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); + sentMessageCounter.increment(); + try { + sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, + recipient, multiRecipientMessage.commonPayload()); + } catch (NoSuchUserException e) { + uuids404.add(recipient.uuid()); + } + }, + multiRecipientMessageExecutor)) + .toArray(CompletableFuture[]::new)) + .get(); } catch (InterruptedException e) { logger.error("interrupted while delivering multi-recipient messages", e); return Response.serverError().entity("interrupted during delivery").build(); + } catch (CancellationException e) { + logger.error("cancelled while delivering multi-recipient messages", e); + return Response.serverError().entity("delivery cancelled").build(); + } catch (ExecutionException e) { + logger.error("partial failure while delivering multi-recipient messages", e.getCause()); + return Response.serverError().entity("failure during delivery").build(); } return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index d7db49867..b6c4ae0ee 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -32,6 +32,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -166,7 +167,7 @@ class MessageControllerTest { private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); - private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); + private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @@ -252,8 +253,7 @@ class MessageControllerTest { rateLimiter, cardinalityEstimator, pushNotificationManager, - reportMessageManager, - multiRecipientMessageExecutor + reportMessageManager ); } @@ -990,19 +990,6 @@ class MessageControllerTest { // set up the entity to use in our PUT request Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - when(multiRecipientMessageExecutor.invokeAll(any())) - .thenAnswer(answer -> { - final List tasks = answer.getArgument(0, List.class); - tasks.forEach(c -> { - try { - c.call(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - return null; - }); - // start building the request Invocation.Builder bldr = resources .getJerseyTest() @@ -1110,6 +1097,32 @@ class MessageControllerTest { ); } + @Test + void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception { + UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty()); + + final List recipients = List.of( + new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, + new byte[48]), + new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48])); + + Response response = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("ts", 1700000000000L) + .queryParam("story", true) + .queryParam("urgent", false) + .request() + .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) + .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true), + MultiRecipientMessageProvider.MEDIA_TYPE)); + + checkGoodMultiRecipientResponse(response, 1); + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { @@ -1316,19 +1329,6 @@ class MessageControllerTest { void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier) throws NotPushRegisteredException, InterruptedException { - when(multiRecipientMessageExecutor.invokeAll(any())) - .thenAnswer(answer -> { - final List tasks = answer.getArgument(0, List.class); - tasks.forEach(c -> { - try { - c.call(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - return null; - }); - final List recipients = List.of( new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); @@ -1371,14 +1371,12 @@ class MessageControllerTest { private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); - verify(multiRecipientMessageExecutor, never()).invokeAll(any()); } private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception { assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); - verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture()); - assert (captor.getValue().size() == expectedCount); + verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean()); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); assert (smrmr.uuids404().isEmpty()); }