From e87468fbe0cab9972a4833d8cfb24891c4be5ccd Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 14 Jul 2023 15:16:33 -0400 Subject: [PATCH] Add a rate limit for inbound message bytes for a given account --- .../textsecuregcm/controllers/MessageController.java | 11 +++++++---- .../textsecuregcm/limits/RateLimiter.java | 4 ++++ .../textsecuregcm/limits/RateLimiters.java | 7 ++++++- .../controllers/MessageControllerTest.java | 1 + 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 6d23f1b20..79ccece30 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -186,9 +186,7 @@ public class MessageController { @PathParam("destination") UUID destinationUuid, @QueryParam("story") boolean isStory, @NotNull @Valid IncomingMessageList messages, - @Context ContainerRequestContext context - ) - throws RateLimitExceededException { + @Context ContainerRequestContext context) throws RateLimitExceededException { if (source.isEmpty() && accessKey.isEmpty() && !isStory) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); @@ -213,8 +211,9 @@ public class MessageController { spamReportToken = Optional.empty(); } - for (final IncomingMessage message : messages.messages()) { + int totalContentLength = 0; + for (final IncomingMessage message : messages.messages()) { int contentLength = 0; if (!Util.isEmpty(message.content())) { @@ -223,8 +222,12 @@ public class MessageController { validateContentLength(contentLength, userAgent); validateEnvelopeType(message.type(), userAgent); + + totalContentLength += contentLength; } + rateLimiters.getInboundMessageBytes().validate(destinationUuid, totalContentLength); + try { boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index d064fc819..70e7d1cd3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -33,6 +33,10 @@ public interface RateLimiter { validate(accountUuid.toString()); } + default void validate(final UUID accountUuid, final int permits) throws RateLimitExceededException { + validate(accountUuid.toString(), permits); + } + default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException { validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index ec4c5739d..6737e500a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -48,7 +48,8 @@ public class RateLimiters extends BaseRateLimiters { RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))), PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofSeconds(144))), PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))), - CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofSeconds(15))); + CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), + INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))); private final String id; @@ -211,4 +212,8 @@ public class RateLimiters extends BaseRateLimiters { public RateLimiter getCreateCallLinkLimiter() { return forDescriptor(For.CREATE_CALL_LINK); } + + public RateLimiter getInboundMessageBytes() { + return forDescriptor(For.INBOUND_MESSAGE_BYTES); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 617e8dd21..bbb41a4d7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -205,6 +205,7 @@ class MessageControllerTest { when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); } private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {