Add support for host filtering

This commit is contained in:
Moxie Marlinspike 2018-12-17 14:46:40 -08:00
parent b97fd17146
commit 2daabd000f
7 changed files with 226 additions and 9 deletions

View File

@ -78,6 +78,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private DataSourceFactory messageStore;
@Valid
@NotNull
@JsonProperty
private DataSourceFactory abuseDatabase;
@Valid
@NotNull
@JsonProperty
@ -181,6 +186,10 @@ public class WhisperServerConfiguration extends Configuration {
return messageStore;
}
public DataSourceFactory getAbuseDatabaseConfiguration() {
return abuseDatabase;
}
public DataSourceFactory getDataSourceFactory() {
return database;
}

View File

@ -91,6 +91,7 @@ import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.db.PooledDataSourceFactory;
import io.dropwizard.jdbi.DBIFactory;
import io.dropwizard.setup.Bootstrap;
import io.dropwizard.setup.Environment;
@ -122,6 +123,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
return configuration.getMessageStoreConfiguration();
}
});
bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("abusedb", "abusedb.xml") {
@Override
public PooledDataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) {
return configuration.getAbuseDatabaseConfiguration();
}
});
}
@Override
@ -141,12 +149,14 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DBIFactory dbiFactory = new DBIFactory();
DBI database = dbiFactory.build(environment, config.getDataSourceFactory(), "accountdb");
DBI messagedb = dbiFactory.build(environment, config.getMessageStoreConfiguration(), "messagedb");
DBI abusedb = dbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb");
Accounts accounts = database.onDemand(Accounts.class);
PendingAccounts pendingAccounts = database.onDemand(PendingAccounts.class);
PendingDevices pendingDevices = database.onDemand(PendingDevices.class);
Keys keys = database.onDemand(Keys.class);
Messages messages = messagedb.onDemand(Messages.class);
Accounts accounts = database.onDemand(Accounts.class );
PendingAccounts pendingAccounts = database.onDemand(PendingAccounts.class);
PendingDevices pendingDevices = database.onDemand(PendingDevices.class );
Keys keys = database.onDemand(Keys.class );
Messages messages = messagedb.onDemand(Messages.class);
AbusiveHostRules abusiveHostRules = abusedb.onDemand(AbusiveHostRules.class);
RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl(), config.getCacheConfiguration().getReplicaUrls() );
RedisClientFactory directoryClientFactory = new RedisClientFactory(config.getDirectoryConfiguration().getRedisConfiguration().getUrl(), config.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls() );
@ -209,7 +219,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.buildAuthFilter()));
environment.jersey().register(new AuthValueFactoryProvider.Binder<>(Account.class));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, directoryQueue, messagesManager, turnTokenGenerator, config.getTestDevices()));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, turnTokenGenerator, config.getTestDevices()));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices()));
environment.jersey().register(new DirectoryController(rateLimiters, directory, directoryCredentialsGenerator));
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));

View File

@ -38,6 +38,8 @@ import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -63,6 +65,7 @@ import javax.ws.rs.core.Response;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
@ -80,6 +83,7 @@ public class AccountController {
private final PendingAccountsManager pendingAccounts;
private final AccountsManager accounts;
private final AbusiveHostRules abusiveHostRules;
private final RateLimiters rateLimiters;
private final SmsSender smsSender;
private final DirectoryQueue directoryQueue;
@ -89,6 +93,7 @@ public class AccountController {
public AccountController(PendingAccountsManager pendingAccounts,
AccountsManager accounts,
AbusiveHostRules abusiveHostRules,
RateLimiters rateLimiters,
SmsSender smsSenderFactory,
DirectoryQueue directoryQueue,
@ -98,6 +103,7 @@ public class AccountController {
{
this.pendingAccounts = pendingAccounts;
this.accounts = accounts;
this.abusiveHostRules = abusiveHostRules;
this.rateLimiters = rateLimiters;
this.smsSender = smsSenderFactory;
this.directoryQueue = directoryQueue;
@ -121,6 +127,22 @@ public class AccountController {
throw new WebApplicationException(Response.status(400).build());
}
List<AbusiveHostRule> abuseRules = abusiveHostRules.getAbusiveHostRulesFor(requester);
for (AbusiveHostRule abuseRule : abuseRules) {
if (abuseRule.isBlocked()) {
logger.info("Blocked host: " + transport + ", " + number + ", " + requester);
return Response.ok().build();
}
if (!abuseRule.getRegions().isEmpty()) {
if (abuseRule.getRegions().stream().noneMatch(number::startsWith)) {
logger.info("Restricted host: " + transport + ", " + number + ", " + requester);
return Response.ok().build();
}
}
}
try {
rateLimiters.getSmsVoiceIpLimiter().validate(requester);
} catch (RateLimitExceededException e) {

View File

@ -0,0 +1,29 @@
package org.whispersystems.textsecuregcm.storage;
import java.net.InetAddress;
import java.util.List;
public class AbusiveHostRule {
private final String host;
private final boolean blocked;
private final List<String> regions;
public AbusiveHostRule(String host, boolean blocked, List<String> regions) {
this.host = host;
this.blocked = blocked;
this.regions = regions;
}
public List<String> getRegions() {
return regions;
}
public boolean isBlocked() {
return blocked;
}
public String getHost() {
return host;
}
}

View File

@ -0,0 +1,43 @@
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
public abstract class AbusiveHostRules {
private static final String ID = "id";
private static final String HOST = "host";
private static final String BLOCKED = "blocked";
private static final String REGIONS = "regions";
@Mapper(AbusiveHostRuleMapper.class)
@SqlQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST)
public abstract List<AbusiveHostRule> getAbusiveHostRulesFor(@Bind("host") String host);
public static class AbusiveHostRuleMapper implements ResultSetMapper<AbusiveHostRule> {
@Override
public AbusiveHostRule map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
String regionsData = resultSet.getString(REGIONS);
List<String> regions;
if (regionsData == null) regions = new LinkedList<>();
else regions = Arrays.asList(regionsData.split(","));
return new AbusiveHostRule(resultSet.getString(HOST), resultSet.getInt(BLOCKED) == 1, regions);
}
}
}

View File

@ -0,0 +1,31 @@
<?xml version="1.0" encoding="UTF-8"?>
<databaseChangeLog
xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-2.0.xsd">
<changeSet id="1" author="moxie">
<createTable tableName="abusive_host_rules">
<column name="id" type="bigint" autoIncrement="true">
<constraints primaryKey="true" nullable="false"/>
</column>
<column name="host" type="inet">
<constraints nullable="false" unique="true"/>
</column>
<column name="blocked" type="tinyint">
<constraints nullable="false"/>
</column>
<column name="regions" type="text"/>
</createTable>
<createIndex tableName="abusive_host_rules" indexName="host_index">
<column name="host"/>
</createIndex>
</changeSet>
</databaseChangeLog>

View File

@ -17,6 +17,8 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -27,7 +29,9 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
@ -44,8 +48,13 @@ public class AccountControllerTest {
private static final String SENDER_PIN = "+14153333333";
private static final String SENDER_OVER_PIN = "+14154444444";
private static final String ABUSIVE_HOST = "192.168.1.1";
private static final String RESTRICTED_HOST = "192.168.1.2";
private static final String NICE_HOST = "127.0.0.1";
private PendingAccountsManager pendingAccountsManager = mock(PendingAccountsManager.class);
private AccountsManager accountsManager = mock(AccountsManager.class );
private AbusiveHostRules abusiveHostRules = mock(AbusiveHostRules.class );
private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class );
private RateLimiter pinLimiter = mock(RateLimiter.class );
@ -66,6 +75,7 @@ public class AccountControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AccountController(pendingAccountsManager,
accountsManager,
abusiveHostRules,
rateLimiters,
smsSender,
directoryQueue,
@ -88,7 +98,6 @@ public class AccountControllerTest {
when(senderPinAccount.getPin()).thenReturn(Optional.of("31337"));
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis())));
when(pendingAccountsManager.getCodeForNumber(SENDER_OLD)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31))));
when(pendingAccountsManager.getCodeForNumber(SENDER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("333333", System.currentTimeMillis())));
@ -99,6 +108,10 @@ public class AccountControllerTest {
when(accountsManager.get(eq(SENDER))).thenReturn(Optional.empty());
when(accountsManager.get(eq(SENDER_OLD))).thenReturn(Optional.empty());
when(abusiveHostRules.getAbusiveHostRulesFor(eq(ABUSIVE_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(ABUSIVE_HOST, true, Collections.emptyList())));
when(abusiveHostRules.getAbusiveHostRulesFor(eq(RESTRICTED_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(RESTRICTED_HOST, false, Collections.singletonList("+123"))));
when(abusiveHostRules.getAbusiveHostRulesFor(eq(NICE_HOST))).thenReturn(Collections.emptyList());
doThrow(new RateLimitExceededException(SENDER_OVER_PIN)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
}
@ -108,12 +121,13 @@ public class AccountControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER))
.request()
.header("X-Forwarded-For", "127.0.0.1")
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(smsSender).deliverSmsVerification(eq(SENDER), eq(Optional.empty()), anyString());
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(NICE_HOST));
}
@Test
@ -123,7 +137,7 @@ public class AccountControllerTest {
.target(String.format("/v1/accounts/sms/code/%s", SENDER))
.queryParam("client", "ios")
.request()
.header("X-Forwarded-For", "127.0.0.1")
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
@ -131,6 +145,65 @@ public class AccountControllerTest {
verify(smsSender).deliverSmsVerification(eq(SENDER), eq(Optional.of("ios")), anyString());
}
@Test
public void testSendAndroidNgCode() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER))
.queryParam("client", "android-ng")
.request()
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(smsSender).deliverSmsVerification(eq(SENDER), eq(Optional.of("android-ng")), anyString());
}
@Test
public void testSendAbusiveHost() {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER))
.request()
.header("X-Forwarded-For", ABUSIVE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(ABUSIVE_HOST));
verifyNoMoreInteractions(smsSender);
}
@Test
public void testSendRestrictedHostOut() {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER))
.request()
.header("X-Forwarded-For", RESTRICTED_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(RESTRICTED_HOST));
verifyNoMoreInteractions(smsSender);
}
@Test
public void testSendRestrictedIn() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", "+1234567890"))
.request()
.header("X-Forwarded-For", RESTRICTED_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(smsSender).deliverSmsVerification(eq("+1234567890"), eq(Optional.empty()), anyString());
}
@Test
public void testVerifyCode() throws Exception {
Response response =