Update rate-limiting for requests matching specific criteria

This commit is contained in:
Jon Chambers 2021-08-02 11:26:31 -04:00 committed by Jon Chambers
parent 64eeb1e361
commit eedeaaecee
5 changed files with 13 additions and 127 deletions

View File

@ -395,8 +395,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ScheduledExecutorService recurringJobExecutor = environment.lifecycle() ScheduledExecutorService recurringJobExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "recurringJob-%d")).threads(6).build(); .scheduledExecutorService(name(getClass(), "recurringJob-%d")).threads(6).build();
ScheduledExecutorService declinedMessageReceiptExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "declined-receipt-%d")).threads(2).build();
ScheduledExecutorService retrySchedulingExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "retry-%d")).threads(2).build(); ScheduledExecutorService retrySchedulingExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "retry-%d")).threads(2).build();
ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle().executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(16).workQueue(keyspaceNotificationDispatchQueue).build(); ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle().executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(16).workQueue(keyspaceNotificationDispatchQueue).build();
ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")).maxThreads(1).minThreads(1).build(); ExecutorService apnSenderExecutor = environment.lifecycle().executorService(name(getClass(), "apnSender-%d")).maxThreads(1).minThreads(1).build();
@ -624,7 +622,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new DirectoryController(directoryCredentialsGenerator), new DirectoryController(directoryCredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(), new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new, stripeExecutor, config.getDonationConfiguration(), config.getStripe()), ReceiptCredentialPresentation::new, stripeExecutor, config.getDonationConfiguration(), config.getStripe()),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, declinedMessageReceiptExecutor, multiRecipientMessageExecutor), new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, multiRecipientMessageExecutor),
new PaymentsController(currencyManager, paymentsCredentialsGenerator), new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations), new ProfileController(clock, rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations),
new ProvisioningController(rateLimiters, provisioningManager), new ProvisioningController(rateLimiters, provisioningManager),

View File

@ -21,22 +21,6 @@ public class DynamicMessageRateConfiguration {
@JsonProperty @JsonProperty
private Set<String> rateLimitedHosts = Collections.emptySet(); private Set<String> rateLimitedHosts = Collections.emptySet();
@JsonProperty
private Duration responseDelay = Duration.ofNanos(1_200_000);
@JsonProperty
private Duration responseDelayJitter = Duration.ofNanos(500_000);
@JsonProperty
private Duration receiptDelay = Duration.ofMillis(1_200);
@JsonProperty
private Duration receiptDelayJitter = Duration.ofMillis(800);
@JsonProperty
private double receiptProbability = 0.82;
public boolean isEnforceUnsealedSenderRateLimit() { public boolean isEnforceUnsealedSenderRateLimit() {
return enforceUnsealedSenderRateLimit; return enforceUnsealedSenderRateLimit;
} }
@ -48,24 +32,4 @@ public class DynamicMessageRateConfiguration {
public Set<String> getRateLimitedHosts() { public Set<String> getRateLimitedHosts() {
return rateLimitedHosts; return rateLimitedHosts;
} }
public Duration getResponseDelay() {
return responseDelay;
}
public Duration getResponseDelayJitter() {
return responseDelayJitter;
}
public Duration getReceiptDelay() {
return receiptDelay;
}
public Duration getReceiptDelayJitter() {
return receiptDelayJitter;
}
public double getReceiptProbability() {
return receiptProbability;
}
} }

View File

@ -35,13 +35,10 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -69,7 +66,6 @@ import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@ -134,11 +130,8 @@ public class MessageController {
private final DynamicConfigurationManager dynamicConfigurationManager; private final DynamicConfigurationManager dynamicConfigurationManager;
private final RateLimitChallengeManager rateLimitChallengeManager; private final RateLimitChallengeManager rateLimitChallengeManager;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ScheduledExecutorService receiptExecutorService;
private final ExecutorService multiRecipientMessageExecutor; private final ExecutorService multiRecipientMessageExecutor;
private final Random random = new Random();
private final ClusterLuaScript recordInternationalUnsealedSenderMetricsScript; private final ClusterLuaScript recordInternationalUnsealedSenderMetricsScript;
private static final String LEGACY_MESSAGE_SENT_COUNTER = name(MessageController.class, "legacyMessageSent"); private static final String LEGACY_MESSAGE_SENT_COUNTER = name(MessageController.class, "legacyMessageSent");
@ -146,7 +139,6 @@ public class MessageController {
private static final String REJECT_UNSEALED_SENDER_COUNTER_NAME = name(MessageController.class, "rejectUnsealedSenderLimit"); private static final String REJECT_UNSEALED_SENDER_COUNTER_NAME = name(MessageController.class, "rejectUnsealedSenderLimit");
private static final String INTERNATIONAL_UNSEALED_SENDER_COUNTER_NAME = name(MessageController.class, "internationalUnsealedSender"); private static final String INTERNATIONAL_UNSEALED_SENDER_COUNTER_NAME = name(MessageController.class, "internationalUnsealedSender");
private static final String UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME = name(MessageController.class, "unsealedSenderWithoutPushToken"); private static final String UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME = name(MessageController.class, "unsealedSenderWithoutPushToken");
private static final String DECLINED_DELIVERY_COUNTER = name(MessageController.class, "declinedDelivery");
private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize"); private static final String CONTENT_SIZE_DISTRIBUTION_NAME = name(MessageController.class, "messageContentSize");
private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes"); private static final String OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME = name(MessageController.class, "outgoingMessageListSizeBytes");
@ -168,7 +160,6 @@ public class MessageController {
RateLimitChallengeManager rateLimitChallengeManager, RateLimitChallengeManager rateLimitChallengeManager,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
FaultTolerantRedisCluster metricsCluster, FaultTolerantRedisCluster metricsCluster,
ScheduledExecutorService receiptExecutorService,
@Nonnull ExecutorService multiRecipientMessageExecutor) { @Nonnull ExecutorService multiRecipientMessageExecutor) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.messageSender = messageSender; this.messageSender = messageSender;
@ -180,7 +171,6 @@ public class MessageController {
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitChallengeManager = rateLimitChallengeManager; this.rateLimitChallengeManager = rateLimitChallengeManager;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.receiptExecutorService = receiptExecutorService;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor); this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
try { try {
@ -303,7 +293,7 @@ public class MessageController {
.orElse(false); .orElse(false);
if (isRateLimitedHost) { if (isRateLimitedHost) {
return declineDelivery(messages, source.get().getAccount(), destination.get()); throw new RateLimitExceededException(Duration.ofDays(1));
} }
} }
} }
@ -481,47 +471,6 @@ public class MessageController {
} }
} }
private Response declineDelivery(final IncomingMessageList messages, final Account source, final Account destination) {
Metrics.counter(DECLINED_DELIVERY_COUNTER, SENDER_COUNTRY_TAG_NAME, Util.getCountryCode(source.getNumber())).increment();
final DynamicMessageRateConfiguration messageRateConfiguration = dynamicConfigurationManager.getConfiguration().getMessageRateConfiguration();
{
final long timestamp = System.currentTimeMillis();
for (final IncomingMessage message : messages.getMessages()) {
final long jitterNanos = random.nextInt((int) messageRateConfiguration.getReceiptDelayJitter().toNanos());
final Duration receiptDelay = messageRateConfiguration.getReceiptDelay().plusNanos(jitterNanos);
if (random.nextDouble() <= messageRateConfiguration.getReceiptProbability()) {
receiptExecutorService.schedule(() -> {
try {
receiptSender.sendReceipt(
new AuthenticatedAccount(() -> new Pair<>(destination, destination.getMasterDevice().get())),
source.getUuid(), timestamp);
} catch (final NoSuchUserException ignored) {
}
}, receiptDelay.toMillis(), TimeUnit.MILLISECONDS);
}
}
}
{
Duration responseDelay = Duration.ZERO;
for (int i = 0; i < messages.getMessages().size(); i++) {
final long jitterNanos = random.nextInt((int) messageRateConfiguration.getResponseDelayJitter().toNanos());
responseDelay = responseDelay.plus(
messageRateConfiguration.getResponseDelay()).plusNanos(jitterNanos);
}
Util.sleep(responseDelay.toMillis());
}
return Response.ok(new SendMessageResponse(source.getEnabledDeviceCount() > 1)).build();
}
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)

View File

@ -10,7 +10,6 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
@ -45,7 +44,6 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
mock(RateLimitChallengeManager.class), mock(RateLimitChallengeManager.class),
mock(ReportMessageManager.class), mock(ReportMessageManager.class),
getRedisCluster(), getRedisCluster(),
mock(ScheduledExecutorService.class),
mock(ExecutorService.class)); mock(ExecutorService.class));
} }

View File

@ -15,7 +15,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
@ -47,8 +46,6 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
@ -122,19 +119,18 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class); private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
private static final MessageSender messageSender = mock(MessageSender.class); private static final MessageSender messageSender = mock(MessageSender.class);
private static final ReceiptSender receiptSender = mock(ReceiptSender.class); private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final MessagesManager messagesManager = mock(MessagesManager.class); private static final MessagesManager messagesManager = mock(MessagesManager.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class); private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands); private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands);
private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
@ -148,8 +144,7 @@ class MessageControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager,
rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor, rateLimitChallengeManager, reportMessageManager, metricsCluster, multiRecipientMessageExecutor))
multiRecipientMessageExecutor))
.build(); .build();
@BeforeEach @BeforeEach
@ -187,12 +182,6 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
when(receiptExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
(Answer<ScheduledFuture<?>>) invocation -> {
invocation.getArgument(0, Runnable.class).run();
return mock(ScheduledFuture.class);
});
} }
@AfterEach @AfterEach
@ -210,8 +199,7 @@ class MessageControllerTest {
dynamicConfigurationManager, dynamicConfigurationManager,
rateLimitChallengeManager, rateLimitChallengeManager,
reportMessageManager, reportMessageManager,
metricsCluster, metricsCluster
receiptExecutor
); );
} }
@ -271,11 +259,6 @@ class MessageControllerTest {
when(dynamicConfiguration.getMessageRateConfiguration()).thenReturn(messageRateConfiguration); when(dynamicConfiguration.getMessageRateConfiguration()).thenReturn(messageRateConfiguration);
when(messageRateConfiguration.getRateLimitedCountryCodes()).thenReturn(Set.of("1")); when(messageRateConfiguration.getRateLimitedCountryCodes()).thenReturn(Set.of("1"));
when(messageRateConfiguration.getRateLimitedHosts()).thenReturn(Set.of(senderHost)); when(messageRateConfiguration.getRateLimitedHosts()).thenReturn(Set.of(senderHost));
when(messageRateConfiguration.getResponseDelay()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getResponseDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptDelay()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0);
when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L)); when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L));
@ -288,10 +271,9 @@ class MessageControllerTest {
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), .put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat(response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
verify(receiptSender).sendReceipt(any(), eq(AuthHelper.VALID_UUID), anyLong());
} }
@ParameterizedTest @ParameterizedTest
@ -304,11 +286,6 @@ class MessageControllerTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getMessageRateConfiguration()).thenReturn(messageRateConfiguration); when(dynamicConfiguration.getMessageRateConfiguration()).thenReturn(messageRateConfiguration);
when(messageRateConfiguration.isEnforceUnsealedSenderRateLimit()).thenReturn(true); when(messageRateConfiguration.isEnforceUnsealedSenderRateLimit()).thenReturn(true);
when(messageRateConfiguration.getResponseDelay()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getResponseDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptDelay()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0);
DynamicRateLimitChallengeConfiguration dynamicRateLimitChallengeConfiguration = mock( DynamicRateLimitChallengeConfiguration dynamicRateLimitChallengeConfiguration = mock(
DynamicRateLimitChallengeConfiguration.class); DynamicRateLimitChallengeConfiguration.class);