From 68e2c511b73d4beace25b8ec9d76c58356d417cd Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Thu, 13 Feb 2025 16:21:29 -0600 Subject: [PATCH] Split up backup-id rotation rate limits --- .../backup/BackupAuthManager.java | 21 ++- .../textsecuregcm/limits/RateLimiters.java | 3 +- .../backup/BackupAuthManagerTest.java | 150 ++++++++++++------ 3 files changed, 124 insertions(+), 50 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java index f4a762025..fdb661167 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java @@ -14,6 +14,7 @@ import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.stream.Stream; import javax.annotation.Nullable; import org.signal.libsignal.zkgroup.GenericServerSecretParams; @@ -114,9 +115,17 @@ public class BackupAuthManager { return CompletableFuture.completedFuture(null); } - return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID) - .validateAsync(account.getUuid()) - .thenCompose(ignored -> this.accountsManager + CompletionStage rateLimitFuture = rateLimiters + .forDescriptor(RateLimiters.For.SET_BACKUP_ID) + .validateAsync(account.getUuid()); + + if (!mediaCredentialRequestMatches && hasActiveVoucher(account)) { + rateLimitFuture = rateLimitFuture.thenCombine( + rateLimiters.forDescriptor(RateLimiters.For.SET_PAID_MEDIA_BACKUP_ID).validateAsync(account.getUuid()), + (ignore1, ignore2) -> null); + } + + return rateLimitFuture.thenCompose(ignored -> this.accountsManager .updateAsync(account, a -> a.setBackupCredentialRequests(serializedMessageCredentialRequest, serializedMediaCredentialRequest)) .thenRun(Util.NOOP)) .toCompletableFuture(); @@ -280,8 +289,12 @@ public class BackupAuthManager { return next; } + private boolean hasActiveVoucher(final Account account) { + return account.getBackupVoucher() != null && clock.instant().isBefore(account.getBackupVoucher().expiration()); + } + private boolean hasExpiredVoucher(final Account account) { - return account.getBackupVoucher() != null && clock.instant().isAfter(account.getBackupVoucher().expiration()); + return account.getBackupVoucher() != null && !hasActiveVoucher(account); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 05c8347c6..800a9b050 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -41,7 +41,8 @@ public class RateLimiters extends BaseRateLimiters { RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12))), CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))), CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))), - SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(2, Duration.ofDays(7))), + SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1))), + SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7))), PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))), PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))), GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index aed5e61bd..b1871422c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -85,10 +85,14 @@ public class BackupAuthManagerTest { reset(redeemedReceiptsManager); } - BackupAuthManager create(@Nullable BackupLevel backupLevel, boolean rateLimit) { + BackupAuthManager create(@Nullable BackupLevel backupLevel) { + return create(backupLevel, rateLimiter(aci, false, false)); + } + + BackupAuthManager create(@Nullable BackupLevel backupLevel, RateLimiters rateLimiters) { return new BackupAuthManager( ExperimentHelper.withEnrollment(experimentName(backupLevel), aci), - rateLimit ? denyRateLimiter(aci) : allowRateLimiter(), + rateLimiters, accountsManager, new ServerZkReceiptOperations(receiptParams), redeemedReceiptsManager, @@ -98,7 +102,7 @@ public class BackupAuthManagerTest { @Test void commitBackupId() { - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -124,7 +128,7 @@ public class BackupAuthManagerTest { @EnumSource @NullSource void commitRequiresBackupLevel(final BackupLevel backupLevel) { - final BackupAuthManager authManager = create(backupLevel, false); + final BackupAuthManager authManager = create(backupLevel); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); @@ -147,7 +151,7 @@ public class BackupAuthManagerTest { void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel, @CartesianTest.Enum final BackupCredentialType credentialType) { - final BackupAuthManager authManager = create(backupLevel, false); + final BackupAuthManager authManager = create(backupLevel); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -166,7 +170,7 @@ public class BackupAuthManagerTest { @ParameterizedTest @EnumSource void getBackupAuthCredentialsNoBackupLevel(final BackupCredentialType credentialType) { - final BackupAuthManager authManager = create(null, false); + final BackupAuthManager authManager = create(null); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -187,7 +191,7 @@ public class BackupAuthManagerTest { @CartesianTest void getReceiptCredentials(@CartesianTest.Enum final BackupLevel backupLevel, @CartesianTest.Enum final BackupCredentialType credentialType) throws VerificationFailedException { - final BackupAuthManager authManager = create(backupLevel, false); + final BackupAuthManager authManager = create(backupLevel); final byte[] backupKey = switch (credentialType) { case MESSAGES -> messagesBackupKey; @@ -244,7 +248,7 @@ public class BackupAuthManagerTest { @MethodSource void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd, final Instant now) { - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -268,7 +272,7 @@ public class BackupAuthManagerTest { final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4)); final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -301,7 +305,7 @@ public class BackupAuthManagerTest { final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2)); final Instant day3 = Instant.EPOCH.plus(Duration.ofDays(3)); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(3, day1)); @@ -341,7 +345,7 @@ public class BackupAuthManagerTest { @Test void redeemReceipt() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -358,7 +362,7 @@ public class BackupAuthManagerTest { final Instant newExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant existingExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -383,7 +387,7 @@ public class BackupAuthManagerTest { void redeemExpiredReceipt() { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(3, expirationTime)).join()) .extracting(ex -> ex.getStatus().getCode()) @@ -397,7 +401,7 @@ public class BackupAuthManagerTest { void redeemInvalidLevel(long level) { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(level, expirationTime)).join()) @@ -409,7 +413,7 @@ public class BackupAuthManagerTest { @Test void redeemInvalidPresentation() throws InvalidInputException, VerificationFailedException { - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final ReceiptCredentialPresentation invalid = receiptPresentation(ServerSecretParams.generate(), 3L, Instant.EPOCH); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), invalid).join()) @@ -422,7 +426,7 @@ public class BackupAuthManagerTest { @Test void receiptAlreadyRedeemed() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.FREE, false); + final BackupAuthManager authManager = create(BackupLevel.FREE); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -459,30 +463,85 @@ public class BackupAuthManagerTest { } - @Test - void testRateLimits() { - final AccountsManager accountsManager = mock(AccountsManager.class); - final BackupAuthManager authManager = create(BackupLevel.FREE, true); + @CartesianTest + void testChangeIdRateLimits( + @CartesianTest.Values(booleans = {true, false}) boolean changeMessage, + @CartesianTest.Values(booleans = {true, false}) boolean changeMedia, + @CartesianTest.Values(booleans = {true, false}) boolean rateLimitBackupId) { - final BackupAuthCredentialRequest messagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci); - final BackupAuthCredentialRequest mediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci); + final BackupAuthManager authManager = create(BackupLevel.FREE, rateLimiter(aci, rateLimitBackupId, false)); + final BackupAuthCredentialRequest storedMessagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci); + final BackupAuthCredentialRequest storedMediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci); + final Account account = mockAccount(storedMessagesCredential, storedMediaCredential, null); - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(aci); - when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); + final BackupAuthCredentialRequest newMessagesCredential = changeMessage + ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) + : storedMessagesCredential; - // Should be rate limited - CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, - authManager.commitBackupId(account, messagesCredential, mediaCredential)); + final BackupAuthCredentialRequest newMediaCredential = changeMedia + ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) + : storedMediaCredential; - // If we don't change the request, shouldn't be rate limited - when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); - when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) - .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); - assertDoesNotThrow(() -> authManager.commitBackupId(account, messagesCredential, mediaCredential).join()); + final boolean expectRateLimit = (changeMedia || changeMessage) && rateLimitBackupId; + final CompletableFuture future = authManager.commitBackupId(account, newMessagesCredential, newMediaCredential); + if (expectRateLimit) { + CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future); + } else { + assertDoesNotThrow(() -> future.join()); + } } + @CartesianTest + void testChangePaidMediaIdRateLimits( + @CartesianTest.Values(booleans = {true, false}) boolean changeMessage, + @CartesianTest.Values(booleans = {true, false}) boolean changeMedia, + @CartesianTest.Values(booleans = {true, false}) boolean paid, + @CartesianTest.Values(booleans = {true, false}) boolean rateLimitPaidMedia) { + + final BackupAuthManager authManager = create(BackupLevel.FREE, rateLimiter(aci, false, rateLimitPaidMedia)); + final BackupAuthCredentialRequest storedMessagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci); + final BackupAuthCredentialRequest storedMediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci); + // Set clock before the voucher expires if paid, otherwise after + final Account.BackupVoucher backupVoucher = new Account.BackupVoucher(1, Instant.ofEpochSecond(100)); + clock.pin(paid ? Instant.ofEpochSecond(99) : Instant.ofEpochSecond(101)); + + final Account account = mockAccount(storedMessagesCredential, storedMediaCredential, backupVoucher); + + final BackupAuthCredentialRequest newMessagesCredential = changeMessage + ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) + : storedMessagesCredential; + + final BackupAuthCredentialRequest newMediaCredential = changeMedia + ? backupAuthTestUtil.getRequest(TestRandomUtil.nextBytes(32), aci) + : storedMediaCredential; + + // We should get rate limited iff we are out of paid media changes and we changed the media backup-id + final boolean expectRateLimit = changeMedia && paid && rateLimitPaidMedia; + final CompletableFuture future = authManager.commitBackupId(account, newMessagesCredential, newMediaCredential); + if (expectRateLimit) { + CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future); + } else { + assertDoesNotThrow(() -> future.join()); + } + } + + private Account mockAccount(final BackupAuthCredentialRequest storedMessagesCredential, final BackupAuthCredentialRequest storedMediaCredential, Account.BackupVoucher backupVoucher) { + final Account account = mock(Account.class); + when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); + if (storedMessagesCredential != null) { + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(storedMessagesCredential.serialize())); + } + if (storedMediaCredential != null) { + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(storedMediaCredential.serialize())); + } + when(account.getUuid()).thenReturn(aci); + when(account.getBackupVoucher()).thenReturn(backupVoucher); + return account; + } + + private static String experimentName(@Nullable BackupLevel backupLevel) { return switch (backupLevel) { case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; @@ -491,20 +550,21 @@ public class BackupAuthManagerTest { }; } - private static RateLimiters allowRateLimiter() { + private static RateLimiters rateLimiter(final UUID aci, boolean rateLimitBackupId, + boolean rateLimitPaidMediaBackupId) { final RateLimiters limiters = mock(RateLimiters.class); - final RateLimiter limiter = mock(RateLimiter.class); - when(limiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null)); - when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); - return limiters; - } - private static RateLimiters denyRateLimiter(final UUID aci) { - final RateLimiters limiters = mock(RateLimiters.class); - final RateLimiter limiter = mock(RateLimiter.class); - when(limiter.validateAsync(aci)) + final RateLimiter allowLimiter = mock(RateLimiter.class); + when(allowLimiter.validateAsync(aci)).thenReturn(CompletableFuture.completedFuture(null)); + + final RateLimiter denyLimiter = mock(RateLimiter.class); + when(denyLimiter.validateAsync(aci)) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); - when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); + + when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)) + .thenReturn(rateLimitBackupId ? denyLimiter : allowLimiter); + when(limiters.forDescriptor(RateLimiters.For.SET_PAID_MEDIA_BACKUP_ID)) + .thenReturn(rateLimitPaidMediaBackupId ? denyLimiter : allowLimiter); return limiters; } }