From 417d99a17eac86adc6c0b216b7e8b9beb2a3becb Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 1 Dec 2023 16:01:14 -0500 Subject: [PATCH] Check story rate limits in parallel --- .../controllers/MessageController.java | 21 ++++++++++- .../controllers/MessageControllerTest.java | 37 ++++++++++++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 5bb9ca184..8d4f3e055 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -112,6 +112,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.Stories; @@ -150,6 +151,8 @@ public class MessageController { private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8; + private static final CompletableFuture[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; + private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); @@ -447,8 +450,22 @@ public class MessageController { if (recipients.isEmpty()) { return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build(); } - for (MultiRecipientDeliveryData recipient : recipients.values()) { - rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid()); + + try { + CompletableFuture.allOf(recipients.values() + .stream() + .map(recipient -> recipient.account().getUuid()) + .map(accountIdentifier -> + rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture()) + .toList() + .toArray(EMPTY_FUTURE_ARRAY)) + .join(); + } catch (final Exception e) { + if (ExceptionUtils.unwrap(e) instanceof RateLimitExceededException rateLimitExceededException) { + throw rateLimitExceededException; + } else { + throw ExceptionUtils.wrap(e); + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 3d0660cc9..88a6ae3f2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -43,6 +43,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -243,6 +244,8 @@ class MessageControllerTest { when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); + + when(rateLimiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null)); } private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId, @@ -1148,6 +1151,7 @@ class MessageControllerTest { testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages()); } + @SuppressWarnings("unused") private static ArgumentSets testMultiRecipientMessageNoPni() { final Map> targets = multiRecipientTargetMap(); final Map> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID)); @@ -1449,7 +1453,7 @@ class MessageControllerTest { @ParameterizedTest @MethodSource void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2) - throws NotPushRegisteredException, InterruptedException { + throws NotPushRegisteredException { final List recipients = List.of( new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]), @@ -1490,6 +1494,37 @@ class MessageControllerTest { Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); } + @Test + void sendMultiRecipientMessageStoryRateLimited() { + final List recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + // InputStream stream = initializeMultiPayload(recipientUUID, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, true); + + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // start building the request + final Invocation.Builder invocationBuilder = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", false) + .queryParam("ts", System.currentTimeMillis()) + .queryParam("story", true) + .queryParam("urgent", true) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + + when(rateLimiter.validateAsync(any(UUID.class))) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true))); + + try (final Response response = invocationBuilder.put(entity)) { + assertEquals(413, response.getStatus()); + } + } + private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());