Add a rate limit for inbound message bytes for a given account

This commit is contained in:
Jon Chambers 2023-07-14 15:16:33 -04:00 committed by Jon Chambers
parent e38a713ccc
commit e87468fbe0
4 changed files with 18 additions and 5 deletions

View File

@ -186,9 +186,7 @@ public class MessageController {
@PathParam("destination") UUID destinationUuid, @PathParam("destination") UUID destinationUuid,
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@NotNull @Valid IncomingMessageList messages, @NotNull @Valid IncomingMessageList messages,
@Context ContainerRequestContext context @Context ContainerRequestContext context) throws RateLimitExceededException {
)
throws RateLimitExceededException {
if (source.isEmpty() && accessKey.isEmpty() && !isStory) { if (source.isEmpty() && accessKey.isEmpty() && !isStory) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
@ -213,8 +211,9 @@ public class MessageController {
spamReportToken = Optional.empty(); spamReportToken = Optional.empty();
} }
for (final IncomingMessage message : messages.messages()) { int totalContentLength = 0;
for (final IncomingMessage message : messages.messages()) {
int contentLength = 0; int contentLength = 0;
if (!Util.isEmpty(message.content())) { if (!Util.isEmpty(message.content())) {
@ -223,8 +222,12 @@ public class MessageController {
validateContentLength(contentLength, userAgent); validateContentLength(contentLength, userAgent);
validateEnvelopeType(message.type(), userAgent); validateEnvelopeType(message.type(), userAgent);
totalContentLength += contentLength;
} }
rateLimiters.getInboundMessageBytes().validate(destinationUuid, totalContentLength);
try { try {
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid); boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid);

View File

@ -33,6 +33,10 @@ public interface RateLimiter {
validate(accountUuid.toString()); validate(accountUuid.toString());
} }
default void validate(final UUID accountUuid, final int permits) throws RateLimitExceededException {
validate(accountUuid.toString(), permits);
}
default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException { default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
} }

View File

@ -48,7 +48,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))), RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofSeconds(144))), PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofSeconds(144))),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))), PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofMinutes(12))),
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofSeconds(15))); CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000)));
private final String id; private final String id;
@ -211,4 +212,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public RateLimiter getCreateCallLinkLimiter() { public RateLimiter getCreateCallLinkLimiter() {
return forDescriptor(For.CREATE_CALL_LINK); return forDescriptor(For.CREATE_CALL_LINK);
} }
public RateLimiter getInboundMessageBytes() {
return forDescriptor(For.INBOUND_MESSAGE_BYTES);
}
} }

View File

@ -205,6 +205,7 @@ class MessageControllerTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
} }
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) { private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {