From f1904628797e6eb1606cf78e7e68f11bd5448b55 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Thu, 22 Apr 2021 12:41:03 -0500 Subject: [PATCH] Fully implement unsealed sender cardinality rate limiter --- .../DynamicMessageRateConfiguration.java | 1 - .../controllers/MessageController.java | 21 ++++----- .../controllers/MessageControllerTest.java | 44 +++++++++++++++++++ 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessageRateConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessageRateConfiguration.java index c4ec00be0..c344d66e6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessageRateConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicMessageRateConfiguration.java @@ -6,7 +6,6 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; import com.fasterxml.jackson.annotation.JsonProperty; - import java.time.Duration; import java.util.Collections; import java.util.Set; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index efcdb1824..cea695a9f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -172,19 +172,16 @@ public class MessageController { Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment(); } - if (dynamicConfigurationManager.getConfiguration().getMessageRateConfiguration().getRateLimitedCountryCodes().contains(senderCountryCode)) { - try { - rateLimiters.getUnsealedSenderLimiter().validate(source.get().getNumber(), destinationName.toString()); - rateLimiters.getUnsealedSenderLimiter().validate(source.get().getUuid().toString(), destinationName.toString()); - } catch (RateLimitExceededException e) { - Metrics.counter(REJECT_UNSEALED_SENDER_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment(); + try { + rateLimiters.getUnsealedSenderLimiter().validate(source.get().getNumber(), destinationName.toString()); + } catch (RateLimitExceededException e) { - if (dynamicConfigurationManager.getConfiguration().getMessageRateConfiguration().isEnforceUnsealedSenderRateLimit()) { - logger.debug("Rejected unsealed sender limit from: {}", source.get().getNumber()); - throw e; - } else { - logger.debug("Would reject unsealed sender limit from: {}", source.get().getNumber()); - } + if (dynamicConfigurationManager.getConfiguration().getMessageRateConfiguration().isEnforceUnsealedSenderRateLimit()) { + Metrics.counter(REJECT_UNSEALED_SENDER_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment(); + logger.debug("Rejected unsealed sender limit from: {}", source.get().getNumber()); + throw e; + } else { + logger.debug("Would reject unsealed sender limit from: {}", source.get().getNumber()); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index ddaa12cfb..13f84d859 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -18,6 +18,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -56,6 +57,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatcher; import org.mockito.stubbing.Answer; @@ -65,6 +67,7 @@ import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; @@ -75,6 +78,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -120,6 +124,7 @@ class MessageControllerTest { private static final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) + .addProvider(RateLimitExceededExceptionMapper.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, apnFallbackManager, dynamicConfigurationManager, metricsCluster, receiptExecutor)) @@ -248,6 +253,45 @@ class MessageControllerTest { verify(receiptSender).sendReceipt(any(), eq(AuthHelper.VALID_NUMBER), anyLong()); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited) throws Exception { + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); + final DynamicMessageRateConfiguration messageRateConfiguration = mock(DynamicMessageRateConfiguration.class); + + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + when(dynamicConfiguration.getMessageRateConfiguration()).thenReturn(messageRateConfiguration); + when(messageRateConfiguration.isEnforceUnsealedSenderRateLimit()).thenReturn(true); + when(messageRateConfiguration.getResponseDelay()).thenReturn(Duration.ofMillis(1)); + when(messageRateConfiguration.getResponseDelayJitter()).thenReturn(Duration.ofMillis(1)); + when(messageRateConfiguration.getReceiptDelay()).thenReturn(Duration.ofMillis(1)); + when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1)); + when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0); + + when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L)); + + if (rateLimited) { + doThrow(RateLimitExceededException.class) + .when(unsealedSenderLimiter).validate(eq(AuthHelper.VALID_NUMBER), eq(INTERNATIONAL_RECIPIENT)); + } + + Response response = + resources.getJerseyTest() + .target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), + MediaType.APPLICATION_JSON_TYPE)); + + if (rateLimited) { + assertThat("Error Response", response.getStatus(), is(equalTo(413))); + } else { + assertThat("Good Response", response.getStatus(), is(equalTo(200))); + } + + verify(messageSender, rateLimited ? never() : times(1)).sendMessage(any(), any(), any(), anyBoolean()); + } + @Test void testSingleDeviceCurrentUnidentified() throws Exception { Response response =