diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index f398b5123..7e1b73e63 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -488,10 +488,13 @@ class AccountControllerTest { final boolean locateLinkByUuid, final int expectedStatus) { - MockUtils.updateRateLimiterResponseToAllow( - rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, NICE_HOST); - MockUtils.updateRateLimiterResponseToFail( - rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, RATE_LIMITED_IP_HOST, Duration.ofMinutes(10), false); + if (passRateLimiting) { + MockUtils.updateRateLimiterResponseToAllow( + rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, "127.0.0.1"); + } else { + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, "127.0.0.1", Duration.ofMinutes(10), false); + } when(accountsManager.getByUsernameLinkHandle(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); @@ -509,9 +512,7 @@ class AccountControllerTest { if (!stayUnauthenticated) { builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); } - final Response get = builder - .header(HttpHeaders.X_FORWARDED_FOR, passRateLimiting ? NICE_HOST : RATE_LIMITED_IP_HOST) - .get(); + final Response get = builder.get(); assertEquals(expectedStatus, get.getStatus()); } @@ -864,21 +865,18 @@ class AccountControllerTest { assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", accountIdentifier)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() .getStatus()).isEqualTo(200); assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/account/PNI:%s", phoneNumberIdentifier)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() .getStatus()).isEqualTo(200); assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", UUID.randomUUID())) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() .getStatus()).isEqualTo(404); } @@ -896,32 +894,18 @@ class AccountControllerTest { final Response response = resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", accountIdentifier)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head(); assertThat(response.getStatus()).isEqualTo(413); assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(expectedRetryAfter.toSeconds())); } - @Test - void testAccountExistsNoForwardedFor() throws RateLimitExceededException { - final Response response = resources.getJerseyTest() - .target(String.format("/v1/accounts/account/%s", UUID.randomUUID())) - .request() - .header(HttpHeaders.X_FORWARDED_FOR, "") - .head(); - - assertThat(response.getStatus()).isEqualTo(413); - assertThat(Long.parseLong(response.getHeaderString("Retry-After"))).isNotNegative(); - } - @Test void testAccountExistsAuthenticated() { assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", UUID.randomUUID())) .request() .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() .getStatus()).isEqualTo(400); } @@ -936,7 +920,6 @@ class AccountControllerTest { Response response = resources.getJerseyTest() .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get(); assertThat(response.getStatus()).isEqualTo(200); assertThat(response.readEntity(AccountIdentifierResponse.class).uuid().uuid()).isEqualTo(uuid); @@ -948,7 +931,6 @@ class AccountControllerTest { assertThat(resources.getJerseyTest() .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get().getStatus()).isEqualTo(404); } @@ -960,7 +942,6 @@ class AccountControllerTest { final Response response = resources.getJerseyTest() .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get(); assertThat(response.getStatus()).isEqualTo(413); @@ -973,7 +954,6 @@ class AccountControllerTest { .target(String.format("/v1/accounts/username_hash/%s", USERNAME_HASH_1)) .request() .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get() .getStatus()).isEqualTo(400); } @@ -983,14 +963,12 @@ class AccountControllerTest { assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/username_hash/%s", INVALID_USERNAME_HASH)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get() .getStatus()).isEqualTo(422); assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/username_hash/%s", TOO_SHORT_USERNAME_HASH)) .request() - .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get() .getStatus()).isEqualTo(422); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java index ca5532281..3a6ccf763 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java @@ -146,13 +146,14 @@ class ChallengeControllerTest { } final Response response = EXTENSION.target("/v1/challenge") .request() - .header(HttpHeaders.X_FORWARDED_FOR, "10.0.0.1") .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(recaptchaChallengeJson)); assertEquals(200, response.getStatus()); - verify(rateLimitChallengeManager).answerRecaptchaChallenge(eq(AuthHelper.VALID_ACCOUNT), eq("The value of the solved captcha token"), eq("10.0.0.1"), anyString(), eq(hasThreshold ? Optional.of(0.5f) : Optional.empty())); + verify(rateLimitChallengeManager).answerRecaptchaChallenge(eq(AuthHelper.VALID_ACCOUNT), + eq("The value of the solved captcha token"), eq("127.0.0.1"), anyString(), + eq(hasThreshold ? Optional.of(0.5f) : Optional.empty())); } @Test @@ -164,12 +165,12 @@ class ChallengeControllerTest { "captcha": "The value of the solved captcha token" } """; - when(rateLimitChallengeManager.answerRecaptchaChallenge(eq(AuthHelper.VALID_ACCOUNT), eq("The value of the solved captcha token"), eq("10.0.0.1"), anyString(), any())) + when(rateLimitChallengeManager.answerRecaptchaChallenge(eq(AuthHelper.VALID_ACCOUNT), + eq("The value of the solved captcha token"), eq("127.0.0.1"), anyString(), any())) .thenReturn(false); final Response response = EXTENSION.target("/v1/challenge") .request() - .header(HttpHeaders.X_FORWARDED_FOR, "10.0.0.1") .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(recaptchaChallengeJson)); @@ -192,7 +193,6 @@ class ChallengeControllerTest { final Response response = EXTENSION.target("/v1/challenge") .request() - .header(HttpHeaders.X_FORWARDED_FOR, "10.0.0.1") .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(recaptchaChallengeJson)); @@ -200,25 +200,6 @@ class ChallengeControllerTest { assertEquals(String.valueOf(retryAfter.toSeconds()), response.getHeaderString("Retry-After")); } - @Test - void testHandleRecaptchaNoForwardedFor() { - final String recaptchaChallengeJson = """ - { - "type": "recaptcha", - "token": "A server-generated token", - "captcha": "The value of the solved captcha token" - } - """; - - final Response response = EXTENSION.target("/v1/challenge") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.json(recaptchaChallengeJson)); - - assertEquals(400, response.getStatus()); - verifyNoInteractions(rateLimitChallengeManager); - } - @Test void testHandleUnrecognizedAnswer() { final String unrecognizedJson = """ diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java index fe5e2421f..3f63155d6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java @@ -6,6 +6,11 @@ package org.whispersystems.textsecuregcm.limits; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.common.net.HttpHeaders; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -17,7 +22,6 @@ import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mockito; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -25,18 +29,10 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) public class RateLimitedByIpTest { - private static final String IP = "70.130.130.200"; - - private static final String VALID_X_FORWARDED_FOR = "1.1.1.1," + IP; - - private static final String INVALID_X_FORWARDED_FOR = "1.1.1.1,"; + private static final String IP = "127.0.0.1"; private static final Duration RETRY_AFTER = Duration.ofSeconds(100); - private static final Duration RETRY_AFTER_INVALID_HEADER = RateLimitByIpFilter.INVALID_HEADER_EXCEPTION - .getRetryDuration() - .orElseThrow(); - @Path("/test") public static class Controller { @@ -55,10 +51,10 @@ public class RateLimitedByIpTest { } } - private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class); + private static final RateLimiter RATE_LIMITER = mock(RateLimiter.class); private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl -> - Mockito.when(rl.forDescriptor(Mockito.eq(RateLimiters.For.BACKUP_AUTH_CHECK))).thenReturn(RATE_LIMITER)); + when(rl.forDescriptor(eq(RateLimiters.For.BACKUP_AUTH_CHECK))).thenReturn(RATE_LIMITER)); private static final ResourceExtension RESOURCES = ResourceExtension.builder() .setMapper(SystemMapper.jsonMapper()) @@ -69,49 +65,29 @@ public class RateLimitedByIpTest { @Test public void testRateLimits() throws Exception { - Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); - validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP)); - validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); - Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); - validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP)); - validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); + doNothing().when(RATE_LIMITER).validate(eq(IP)); + validateSuccess("/test/strict"); + doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(eq(IP)); + validateFailure("/test/strict", RETRY_AFTER); + doNothing().when(RATE_LIMITER).validate(eq(IP)); + validateSuccess("/test/strict"); + doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(eq(IP)); + validateFailure("/test/strict", RETRY_AFTER); } - @Test - public void testInvalidHeader() throws Exception { - Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); - validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); - validateFailure("/test/strict", INVALID_X_FORWARDED_FOR, RETRY_AFTER_INVALID_HEADER); - validateFailure("/test/strict", "", RETRY_AFTER_INVALID_HEADER); - - validateSuccess("/test/loose", VALID_X_FORWARDED_FOR); - validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); - validateSuccess("/test/loose", ""); - - // also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.anyString()); - validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER); - validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); - validateSuccess("/test/loose", ""); - } - - private static void validateSuccess(final String path, final String xff) { + private static void validateSuccess(final String path) { final Response response = RESOURCES.getJerseyTest() .target(path) .request() - .header(HttpHeaders.X_FORWARDED_FOR, xff) .get(); assertEquals(200, response.getStatus()); } - private static void validateFailure(final String path, final String xff, final Duration expectedRetryAfter) { + private static void validateFailure(final String path, final Duration expectedRetryAfter) { final Response response = RESOURCES.getJerseyTest() .target(path) .request() - .header(HttpHeaders.X_FORWARDED_FOR, xff) .get(); assertEquals(413, response.getStatus());