Check story rate limits in parallel

This commit is contained in:
Jon Chambers 2023-12-01 16:01:14 -05:00 committed by Jon Chambers
parent e9708b9259
commit 417d99a17e
2 changed files with 55 additions and 3 deletions

View File

@ -112,6 +112,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
@ -150,6 +151,8 @@ public class MessageController {
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture<?>[0];
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage");
private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
@ -447,8 +450,22 @@ public class MessageController {
if (recipients.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
}
for (MultiRecipientDeliveryData recipient : recipients.values()) {
rateLimiters.getStoriesLimiter().validate(recipient.account().getUuid());
try {
CompletableFuture.allOf(recipients.values()
.stream()
.map(recipient -> recipient.account().getUuid())
.map(accountIdentifier ->
rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture())
.toList()
.toArray(EMPTY_FUTURE_ARRAY))
.join();
} catch (final Exception e) {
if (ExceptionUtils.unwrap(e) instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
} else {
throw ExceptionUtils.wrap(e);
}
}
}

View File

@ -43,6 +43,7 @@ import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
@ -243,6 +244,8 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
when(rateLimiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null));
}
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
@ -1148,6 +1151,7 @@ class MessageControllerTest {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
}
@SuppressWarnings("unused")
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
@ -1449,7 +1453,7 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
throws NotPushRegisteredException, InterruptedException {
throws NotPushRegisteredException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
@ -1490,6 +1494,37 @@ class MessageControllerTest {
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
@Test
void sendMultiRecipientMessageStoryRateLimited() {
final List<Recipient> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true)));
try (final Response response = invocationBuilder.put(entity)) {
assertEquals(413, response.getStatus());
}
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());