report exceptions in fanned-out sends of multi-recipient messages

This commit is contained in:
Jonathan Klabunde Tomer 2023-11-20 10:46:26 -08:00 committed by GitHub
parent db7f18aae7
commit cb1fc734c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 48 deletions

View File

@ -32,8 +32,9 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
@ -466,25 +467,37 @@ public class MessageController {
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));
multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.recipients())
.map(recipient -> (Callable<Void>) () -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
CompletableFuture.allOf(
Arrays.stream(multiRecipientMessage.recipients())
// If we're sending a story, some recipients might not map to existing accounts
.filter(recipient -> accountsByServiceIdentifier.containsKey(recipient.uuid()))
.map(
recipient -> CompletableFuture.runAsync(
() -> {
Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
return null;
})
.collect(Collectors.toList()));
// we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment();
try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) {
uuids404.add(recipient.uuid());
}
},
multiRecipientMessageExecutor))
.toArray(CompletableFuture[]::new))
.get();
} catch (InterruptedException e) {
logger.error("interrupted while delivering multi-recipient messages", e);
return Response.serverError().entity("interrupted during delivery").build();
} catch (CancellationException e) {
logger.error("cancelled while delivering multi-recipient messages", e);
return Response.serverError().entity("delivery cancelled").build();
} catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
return Response.serverError().entity("failure during delivery").build();
}
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
}

View File

@ -32,6 +32,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -166,7 +167,7 @@ class MessageControllerTest {
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -252,8 +253,7 @@ class MessageControllerTest {
rateLimiter,
cardinalityEstimator,
pushNotificationManager,
reportMessageManager,
multiRecipientMessageExecutor
reportMessageManager
);
}
@ -990,19 +990,6 @@ class MessageControllerTest {
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});
// start building the request
Invocation.Builder bldr = resources
.getJerseyTest()
@ -1110,6 +1097,32 @@ class MessageControllerTest {
);
}
@Test
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
new byte[48]),
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1700000000000L)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
MultiRecipientMessageProvider.MEDIA_TYPE));
checkGoodMultiRecipientResponse(response, 1);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
@ -1316,19 +1329,6 @@ class MessageControllerTest {
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
throws NotPushRegisteredException, InterruptedException {
when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
@ -1371,14 +1371,12 @@ class MessageControllerTest {
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
verify(multiRecipientMessageExecutor, never()).invokeAll(any());
}
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount);
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.uuids404().isEmpty());
}