diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index e8efd4835..b91b4b1bb 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -65,10 +65,12 @@ import javax.ws.rs.core.Response; import java.io.IOException; import java.security.MessageDigest; import java.security.SecureRandom; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static com.codahale.metrics.MetricRegistry.name; import io.dropwizard.auth.Auth; @@ -117,7 +119,7 @@ public class AccountController { @Path("/{transport}/code/{number}") public Response createAccount(@PathParam("transport") String transport, @PathParam("number") String number, - @HeaderParam("X-Forwarded-For") String requester, + @HeaderParam("X-Forwarded-For") String forwardedFor, @HeaderParam("Accept-Language") Optional locale, @QueryParam("client") Optional client) throws IOException, RateLimitExceededException @@ -127,27 +129,36 @@ public class AccountController { throw new WebApplicationException(Response.status(400).build()); } - List abuseRules = abusiveHostRules.getAbusiveHostRulesFor(requester); + List requesters = Arrays.stream(forwardedFor.split(",")).map(String::trim).collect(Collectors.toList()); - for (AbusiveHostRule abuseRule : abuseRules) { - if (abuseRule.isBlocked()) { - logger.info("Blocked host: " + transport + ", " + number + ", " + requester); - return Response.ok().build(); - } - - if (!abuseRule.getRegions().isEmpty()) { - if (abuseRule.getRegions().stream().noneMatch(number::startsWith)) { - logger.info("Restricted host: " + transport + ", " + number + ", " + requester); - return Response.ok().build(); - } - } + if (requesters.size() > 10) { + logger.info("Request with more than 10 hops: " + transport + ", " + number + ", " + forwardedFor); + return Response.status(400).build(); } - try { - rateLimiters.getSmsVoiceIpLimiter().validate(requester); - } catch (RateLimitExceededException e) { - logger.info("Rate limited exceeded: " + transport + ", " + number + ", " + requester); - return Response.ok().build(); + for (String requester : requesters) { + List abuseRules = abusiveHostRules.getAbusiveHostRulesFor(requester); + + for (AbusiveHostRule abuseRule : abuseRules) { + if (abuseRule.isBlocked()) { + logger.info("Blocked host: " + transport + ", " + number + ", " + requester + " (" + forwardedFor + ")"); + return Response.ok().build(); + } + + if (!abuseRule.getRegions().isEmpty()) { + if (abuseRule.getRegions().stream().noneMatch(number::startsWith)) { + logger.info("Restricted host: " + transport + ", " + number + ", " + requester + " (" + forwardedFor + ")"); + return Response.ok().build(); + } + } + } + + try { + rateLimiters.getSmsVoiceIpLimiter().validate(requester); + } catch (RateLimitExceededException e) { + logger.info("Rate limited exceeded: " + transport + ", " + number + ", " + requester + " (" + forwardedFor + ")"); + return Response.ok().build(); + } } switch (transport) { diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index 7d2195d92..fa690c402 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -175,6 +175,25 @@ public class AccountControllerTest { verifyNoMoreInteractions(smsSender); } + @Test + public void testSendMultipleHost() { + Response response = + resources.getJerseyTest() + .target(String.format("/v1/accounts/sms/code/%s", SENDER)) + .request() + .header("X-Forwarded-For", NICE_HOST + ", " + ABUSIVE_HOST) + .get(); + + assertThat(response.getStatus()).isEqualTo(200); + + verify(abusiveHostRules, times(1)).getAbusiveHostRulesFor(eq(ABUSIVE_HOST)); + verify(abusiveHostRules, times(1)).getAbusiveHostRulesFor(eq(NICE_HOST)); + + verifyNoMoreInteractions(abusiveHostRules); + verifyNoMoreInteractions(smsSender); + } + + @Test public void testSendRestrictedHostOut() { Response response =