Add second (migration) database to `AbusiveHostRules`

This commit is contained in:
Chris Eager 2021-12-08 12:46:05 -08:00 committed by GitHub
parent 9a5ffea0ad
commit a70b057e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 450 additions and 154 deletions

View File

@ -221,6 +221,12 @@ abuseDatabase: # Postgresql database configuration
password: password password: password
url: jdbc:postgresql://example.com:5432/abusedb url: jdbc:postgresql://example.com:5432/abusedb
newAbuseDatabase: # Postgresql database configuration
driverClass: org.postgresql.Driver
user: example
password: password
url: jdbc:postgresql://new.example.com:5432/abusedb
accountDatabaseCrawler: accountDatabaseCrawler:
chunkSize: 10 # accounts per run chunkSize: 10 # accounts per run
chunkIntervalMs: 60000 # time per run chunkIntervalMs: 60000 # time per run

View File

@ -215,6 +215,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty @JsonProperty
private DatabaseConfiguration abuseDatabase; private DatabaseConfiguration abuseDatabase;
@Valid
@NotNull
@JsonProperty
private DatabaseConfiguration newAbuseDatabase;
@Valid @Valid
@NotNull @NotNull
@JsonProperty @JsonProperty
@ -456,6 +461,10 @@ public class WhisperServerConfiguration extends Configuration {
return abuseDatabase; return abuseDatabase;
} }
public DatabaseConfiguration getNewAbuseDatabaseConfiguration() {
return newAbuseDatabase;
}
public RateLimitsConfiguration getLimitsConfiguration() { public RateLimitsConfiguration getLimitsConfiguration() {
return limits; return limits;
} }

View File

@ -214,6 +214,7 @@ import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.CertificateCommand;
import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand; import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
import org.whispersystems.textsecuregcm.workers.MigrateAbusiveHostRulesCommand;
import org.whispersystems.textsecuregcm.workers.ReserveUsernameCommand; import org.whispersystems.textsecuregcm.workers.ReserveUsernameCommand;
import org.whispersystems.textsecuregcm.workers.ServerVersionCommand; import org.whispersystems.textsecuregcm.workers.ServerVersionCommand;
import org.whispersystems.textsecuregcm.workers.SetCrawlerAccelerationTask; import org.whispersystems.textsecuregcm.workers.SetCrawlerAccelerationTask;
@ -242,6 +243,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new CheckDynamicConfigurationCommand()); bootstrap.addCommand(new CheckDynamicConfigurationCommand());
bootstrap.addCommand(new SetUserDiscoverabilityCommand()); bootstrap.addCommand(new SetUserDiscoverabilityCommand());
bootstrap.addCommand(new ReserveUsernameCommand()); bootstrap.addCommand(new ReserveUsernameCommand());
bootstrap.addCommand(new MigrateAbusiveHostRulesCommand());
bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("abusedb", "abusedb.xml") { bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("abusedb", "abusedb.xml") {
@Override @Override
@ -305,9 +307,14 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
headerControlledResourceBundleLookup); headerControlledResourceBundleLookup);
JdbiFactory jdbiFactory = new JdbiFactory(DefaultNameStrategy.CHECK_EMPTY); JdbiFactory jdbiFactory = new JdbiFactory(DefaultNameStrategy.CHECK_EMPTY);
Jdbi abuseJdbi = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb" ); Jdbi abuseJdbi = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb");
FaultTolerantDatabase abuseDatabase = new FaultTolerantDatabase("abuse_database", abuseJdbi, config.getAbuseDatabaseConfiguration().getCircuitBreakerConfiguration()); FaultTolerantDatabase abuseDatabase = new FaultTolerantDatabase("abuse_database", abuseJdbi,
config.getAbuseDatabaseConfiguration().getCircuitBreakerConfiguration());
Jdbi newAbuseJdbi = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb2");
FaultTolerantDatabase newAbuseDatabase = new FaultTolerantDatabase("abuse_database2", newAbuseJdbi,
config.getAbuseDatabaseConfiguration().getCircuitBreakerConfiguration());
DynamoDbAsyncClient dynamoDbAsyncClient = DynamoDbFromConfig.asyncClient( DynamoDbAsyncClient dynamoDbAsyncClient = DynamoDbFromConfig.asyncClient(
config.getDynamoDbClientConfiguration(), config.getDynamoDbClientConfiguration(),
@ -365,6 +372,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDeletedAccountsDynamoDbConfiguration().getTableName(), config.getDeletedAccountsDynamoDbConfiguration().getTableName(),
config.getDeletedAccountsDynamoDbConfiguration().getNeedsReconciliationIndexName()); config.getDeletedAccountsDynamoDbConfiguration().getNeedsReconciliationIndexName());
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
new DynamicConfigurationManager<>(config.getAppConfig().getApplication(),
config.getAppConfig().getEnvironment(),
config.getAppConfig().getConfigurationName(),
DynamicConfiguration.class);
Accounts accounts = new Accounts(accountsDynamoDbClient, Accounts accounts = new Accounts(accountsDynamoDbClient,
config.getAccountsDynamoDbConfiguration().getTableName(), config.getAccountsDynamoDbConfiguration().getTableName(),
config.getAccountsDynamoDbConfiguration().getPhoneNumberTableName(), config.getAccountsDynamoDbConfiguration().getPhoneNumberTableName(),
@ -381,12 +394,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messageDynamoDb, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messageDynamoDb,
config.getMessageDynamoDbConfiguration().getTableName(), config.getMessageDynamoDbConfiguration().getTableName(),
config.getMessageDynamoDbConfiguration().getTimeToLive()); config.getMessageDynamoDbConfiguration().getTimeToLive());
AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase); AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase, newAbuseDatabase,
RemoteConfigs remoteConfigs = new RemoteConfigs(dynamoDbClient, config.getDynamoDbTables().getRemoteConfig().getTableName()); dynamicConfigurationManager);
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(pushChallengeDynamoDbClient, config.getPushChallengeDynamoDbConfiguration().getTableName()); RemoteConfigs remoteConfigs = new RemoteConfigs(dynamoDbClient,
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(reportMessageDynamoDbClient, config.getReportMessageDynamoDbConfiguration().getTableName(), config.getReportMessageConfiguration().getReportTtl()); config.getDynamoDbTables().getRemoteConfig().getTableName());
VerificationCodeStore pendingAccounts = new VerificationCodeStore(pendingAccountsDynamoDbClient, config.getPendingAccountsDynamoDbConfiguration().getTableName()); PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(pushChallengeDynamoDbClient,
VerificationCodeStore pendingDevices = new VerificationCodeStore(pendingDevicesDynamoDbClient, config.getPendingDevicesDynamoDbConfiguration().getTableName()); config.getPushChallengeDynamoDbConfiguration().getTableName());
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(reportMessageDynamoDbClient,
config.getReportMessageDynamoDbConfiguration().getTableName(),
config.getReportMessageConfiguration().getReportTtl());
VerificationCodeStore pendingAccounts = new VerificationCodeStore(pendingAccountsDynamoDbClient,
config.getPendingAccountsDynamoDbConfiguration().getTableName());
VerificationCodeStore pendingDevices = new VerificationCodeStore(pendingDevicesDynamoDbClient,
config.getPendingDevicesDynamoDbConfiguration().getTableName());
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache", config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration()); RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache", config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool(); ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
@ -430,12 +450,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDirectoryV2Configuration().getDirectoryV2ClientConfiguration() config.getDirectoryV2Configuration().getDirectoryV2ClientConfiguration()
.getUserAuthenticationTokenSharedSecret(), false); .getUserAuthenticationTokenSharedSecret(), false);
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
new DynamicConfigurationManager<>(config.getAppConfig().getApplication(),
config.getAppConfig().getEnvironment(),
config.getAppConfig().getConfigurationName(),
DynamicConfiguration.class);
dynamicConfigurationManager.start(); dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager);

View File

@ -0,0 +1,32 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
public class DynamicAbusiveHostRulesMigrationConfiguration {
@JsonProperty
private boolean newReadEnabled = false;
@JsonProperty
private boolean newWriteEnabled = false;
@JsonProperty
private boolean newPrimary = false;
public boolean isNewReadEnabled() {
return newReadEnabled;
}
public boolean isNewWriteEnabled() {
return newWriteEnabled;
}
public boolean isNewPrimary() {
return newPrimary;
}
}

View File

@ -59,6 +59,10 @@ public class DynamicConfiguration {
@Valid @Valid
private DynamicProfileMigrationConfiguration profileMigration = new DynamicProfileMigrationConfiguration(); private DynamicProfileMigrationConfiguration profileMigration = new DynamicProfileMigrationConfiguration();
@JsonProperty
@Valid
private DynamicAbusiveHostRulesMigrationConfiguration abusiveHostRulesMigration = new DynamicAbusiveHostRulesMigrationConfiguration();
public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration( public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration(
final String experimentName) { final String experimentName) {
return Optional.ofNullable(experiments.get(experimentName)); return Optional.ofNullable(experiments.get(experimentName));
@ -117,4 +121,8 @@ public class DynamicConfiguration {
public DynamicProfileMigrationConfiguration getProfileMigrationConfiguration() { public DynamicProfileMigrationConfiguration getProfileMigrationConfiguration() {
return profileMigration; return profileMigration;
} }
public DynamicAbusiveHostRulesMigrationConfiguration getAbusiveHostRulesMigrationConfiguration() {
return abusiveHostRulesMigration;
}
} }

View File

@ -5,12 +5,11 @@
package org.whispersystems.textsecuregcm.experiment; package org.whispersystems.textsecuregcm.experiment;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -18,115 +17,120 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.slf4j.Logger;
import static com.codahale.metrics.MetricRegistry.name; import org.slf4j.LoggerFactory;
/** /**
* An experiment compares the results of two operations and records metrics to assess how frequently they match. * An experiment compares the results of two operations and records metrics to assess how frequently they match.
*/ */
public class Experiment { public class Experiment {
private final String name; private final String name;
private final Timer matchTimer; private final Timer matchTimer;
private final Timer errorTimer; private final Timer errorTimer;
private final Timer bothPresentMismatchTimer; private final Timer bothPresentMismatchTimer;
private final Timer controlNullMismatchTimer; private final Timer controlNullMismatchTimer;
private final Timer experimentNullMismatchTimer; private final Timer experimentNullMismatchTimer;
private static final String OUTCOME_TAG = "outcome"; private static final String OUTCOME_TAG = "outcome";
private static final String MATCH_OUTCOME = "match"; private static final String MATCH_OUTCOME = "match";
private static final String MISMATCH_OUTCOME = "mismatch"; private static final String MISMATCH_OUTCOME = "mismatch";
private static final String ERROR_OUTCOME = "error"; private static final String ERROR_OUTCOME = "error";
private static final String MISMATCH_TYPE_TAG = "mismatchType"; private static final String MISMATCH_TYPE_TAG = "mismatchType";
private static final String BOTH_PRESENT_MISMATCH = "bothPresent"; private static final String BOTH_PRESENT_MISMATCH = "bothPresent";
private static final String CONTROL_NULL_MISMATCH = "controlResultNull"; private static final String CONTROL_NULL_MISMATCH = "controlResultNull";
private static final String EXPERIMENT_NULL_MISMATCH = "experimentResultNull"; private static final String EXPERIMENT_NULL_MISMATCH = "experimentResultNull";
private static final Logger log = LoggerFactory.getLogger(Experiment.class); private static final Logger log = LoggerFactory.getLogger(Experiment.class);
public Experiment(final String... names) { public Experiment(final String... names) {
this(name(Experiment.class, names), this(name(Experiment.class, names),
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MATCH_OUTCOME), Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MATCH_OUTCOME),
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, ERROR_OUTCOME), Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, ERROR_OUTCOME),
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, BOTH_PRESENT_MISMATCH), Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG,
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, CONTROL_NULL_MISMATCH), BOTH_PRESENT_MISMATCH),
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, EXPERIMENT_NULL_MISMATCH)); Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG,
CONTROL_NULL_MISMATCH),
Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG,
EXPERIMENT_NULL_MISMATCH));
}
@VisibleForTesting
Experiment(final String name, final Timer matchTimer, final Timer errorTimer, final Timer bothPresentMismatchTimer,
final Timer controlNullMismatchTimer, final Timer experimentNullMismatchTimer) {
this.name = name;
this.matchTimer = matchTimer;
this.errorTimer = errorTimer;
this.bothPresentMismatchTimer = bothPresentMismatchTimer;
this.controlNullMismatchTimer = controlNullMismatchTimer;
this.experimentNullMismatchTimer = experimentNullMismatchTimer;
}
public <T> void compareFutureResult(final T expected, final CompletionStage<T> experimentStage) {
final long startNanos = System.nanoTime();
experimentStage.whenComplete((actual, cause) -> {
final long durationNanos = System.nanoTime() - startNanos;
if (cause != null) {
recordError(cause, durationNanos);
} else {
recordResult(expected, actual, durationNanos);
}
});
}
public <T> void compareSupplierResult(final T expected, final Supplier<T> experimentSupplier) {
final long startNanos = System.nanoTime();
try {
final T result = experimentSupplier.get();
recordResult(expected, result, System.nanoTime() - startNanos);
} catch (final Exception e) {
recordError(e, System.nanoTime() - startNanos);
} }
}
@VisibleForTesting public <T> void compareSupplierResultAsync(final T expected, final Supplier<T> experimentSupplier,
Experiment(final String name, final Timer matchTimer, final Timer errorTimer, final Timer bothPresentMismatchTimer, final Timer controlNullMismatchTimer, final Timer experimentNullMismatchTimer) { final Executor executor) {
this.name = name; final long startNanos = System.nanoTime();
this.matchTimer = matchTimer; try {
this.errorTimer = errorTimer; compareFutureResult(expected, CompletableFuture.supplyAsync(experimentSupplier, executor));
} catch (final Exception e) {
this.bothPresentMismatchTimer = bothPresentMismatchTimer; recordError(e, System.nanoTime() - startNanos);
this.controlNullMismatchTimer = controlNullMismatchTimer;
this.experimentNullMismatchTimer = experimentNullMismatchTimer;
} }
}
public <T> void compareFutureResult(final T expected, final CompletionStage<T> experimentStage) { private void recordError(final Throwable cause, final long durationNanos) {
final long startNanos = System.nanoTime(); log.warn("Experiment {} threw an exception.", name, cause);
errorTimer.record(durationNanos, TimeUnit.NANOSECONDS);
}
experimentStage.whenComplete((actual, cause) -> { @VisibleForTesting
final long durationNanos = System.nanoTime() - startNanos; <T> void recordResult(final T expected, final T actual, final long durationNanos) {
if (expected instanceof Optional && actual instanceof Optional) {
recordResult(((Optional) expected).orElse(null), ((Optional) actual).orElse(null), durationNanos);
} else {
final Timer Timer;
if (cause != null) { if (Objects.equals(expected, actual)) {
recordError(cause, durationNanos); Timer = matchTimer;
} else { } else if (expected == null) {
recordResult(expected, actual, durationNanos); Timer = controlNullMismatchTimer;
} } else if (actual == null) {
}); Timer = experimentNullMismatchTimer;
} } else {
Timer = bothPresentMismatchTimer;
public <T> void compareSupplierResult(final T expected, final Supplier<T> experimentSupplier) { }
final long startNanos = System.nanoTime();
Timer.record(durationNanos, TimeUnit.NANOSECONDS);
try {
final T result = experimentSupplier.get();
recordResult(expected, result, System.nanoTime() - startNanos);
} catch (final Exception e) {
recordError(e, System.nanoTime() - startNanos);
}
}
public <T> void compareSupplierResultAsync(final T expected, final Supplier<T> experimentSupplier, final Executor executor) {
final long startNanos = System.nanoTime();
try {
compareFutureResult(expected, CompletableFuture.supplyAsync(experimentSupplier, executor));
} catch (final Exception e) {
recordError(e, System.nanoTime() - startNanos);
}
}
private void recordError(final Throwable cause, final long durationNanos) {
log.warn("Experiment {} threw an exception.", name, cause);
errorTimer.record(durationNanos, TimeUnit.NANOSECONDS);
}
@VisibleForTesting
<T> void recordResult(final T expected, final T actual, final long durationNanos) {
if (expected instanceof Optional && actual instanceof Optional) {
recordResult(((Optional)expected).orElse(null), ((Optional)actual).orElse(null), durationNanos);
} else {
final Timer Timer;
if (Objects.equals(expected, actual)) {
Timer = matchTimer;
} else if (expected == null) {
Timer = controlNullMismatchTimer;
} else if (actual == null) {
Timer = experimentNullMismatchTimer;
} else {
Timer = bothPresentMismatchTimer;
}
Timer.record(durationNanos, TimeUnit.NANOSECONDS);
}
} }
}
} }

View File

@ -5,56 +5,133 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import com.google.common.base.Suppliers;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.storage.mappers.AbusiveHostRuleRowMapper; import org.whispersystems.textsecuregcm.storage.mappers.AbusiveHostRuleRowMapper;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import java.util.List;
import static com.codahale.metrics.MetricRegistry.name;
public class AbusiveHostRules { public class AbusiveHostRules {
public static final String ID = "id"; private static final Logger logger = LoggerFactory.getLogger(AbusiveHostRules.class);
public static final String HOST = "host";
public static final String ID = "id";
public static final String HOST = "host";
public static final String BLOCKED = "blocked"; public static final String BLOCKED = "blocked";
public static final String REGIONS = "regions"; public static final String REGIONS = "regions";
public static final String NOTES = "notes"; public static final String NOTES = "notes";
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer getTimer = metricRegistry.timer(name(AbusiveHostRules.class, "get")); private final Timer getTimer = metricRegistry.timer(name(AbusiveHostRules.class, "get"));
private final Timer insertTimer = metricRegistry.timer(name(AbusiveHostRules.class, "setBlockedHost")); private final Timer insertTimer = metricRegistry.timer(name(AbusiveHostRules.class, "setBlockedHost"));
private final FaultTolerantDatabase database; private final FaultTolerantDatabase oldDatabase;
private final FaultTolerantDatabase newDatabase;
public AbusiveHostRules(FaultTolerantDatabase database) { private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
this.database = database; private final Experiment migrationExperiment = new Experiment("abusiveHostRulesMigration");
this.database.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper());
public AbusiveHostRules(FaultTolerantDatabase oldDatabase, FaultTolerantDatabase newDatabase,
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.oldDatabase = oldDatabase;
this.oldDatabase.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper());
this.newDatabase = newDatabase;
this.newDatabase.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper());
this.dynamicConfigurationManager = dynamicConfigurationManager;
} }
public List<AbusiveHostRule> getAbusiveHostRulesFor(String host) { public List<AbusiveHostRule> getAbusiveHostRulesFor(String host) {
return database.with(jdbi -> jdbi.withHandle(handle -> { final List<AbusiveHostRule> oldDbRules = oldDatabase.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context timer = getTimer.time()) { try (Timer.Context timer = getTimer.time()) {
return handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST) return handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST)
.bind("host", host) .bind("host", host)
.mapTo(AbusiveHostRule.class) .mapTo(AbusiveHostRule.class)
.list(); .list();
} }
})); }));
final Supplier<List<AbusiveHostRule>> newDbRules = Suppliers.memoize(
() -> newDatabase.with(jdbi -> jdbi.withHandle(
handle -> handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST)
.bind("host", host)
.mapTo(AbusiveHostRule.class)
.list())));
if (dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration().isNewReadEnabled()) {
migrationExperiment.compareSupplierResult(oldDbRules, newDbRules);
}
return dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration().isNewPrimary()
? newDbRules.get()
: oldDbRules;
} }
public void setBlockedHost(String host, String notes) { public void setBlockedHost(String host, String notes) {
database.use(jdbi -> jdbi.useHandle(handle -> { oldDatabase.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context timer = insertTimer.time()) { try (Timer.Context timer = insertTimer.time()) {
handle.createUpdate("INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING") handle.createUpdate(
.bind("host", host) "INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING")
.bind("blocked", 1) .bind("host", host)
.bind("notes", notes) .bind("blocked", 1)
.execute(); .bind("notes", notes)
.execute();
}
}));
if (dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration()
.isNewWriteEnabled()) {
try {
newDatabase.use(jdbi -> jdbi.useHandle(handle -> handle.createUpdate(
"INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING")
.bind("host", host)
.bind("blocked", 1)
.bind("notes", notes)
.execute()));
} catch (final Exception e) {
logger.warn("Failed to insert rule in new database", e);
}
}
}
public int migrateAbusiveHostRule(AbusiveHostRule rule, String notes) {
return newDatabase.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context timer = insertTimer.time()) {
return handle.createUpdate(
"INSERT INTO abusive_host_rules(host, blocked, notes, regions) VALUES(:host::inet, :blocked, :notes, :regions) ON CONFLICT DO NOTHING")
.bind("host", rule.getHost())
.bind("blocked", rule.isBlocked() ? 1 : 0)
.bind("notes", notes)
.bind("regions", String.join(",", rule.getRegions()))
.execute();
} }
})); }));
} }
public void forEachInOldDatabase(final BiConsumer<AbusiveHostRule, String> consumer, final int fetchSize) {
final AbusiveHostRuleRowMapper rowMapper = new AbusiveHostRuleRowMapper();
oldDatabase.use(jdbi -> jdbi.useHandle(handle -> handle.useTransaction(transactionHandle ->
transactionHandle.createQuery("SELECT * FROM abusive_host_rules")
.setFetchSize(fetchSize)
.map((resultSet, ctx) -> {
AbusiveHostRule rule = rowMapper.map(resultSet, ctx);
String notes = resultSet.getString(NOTES);
return new Pair<>(rule, notes);
})
.forEach(pair -> consumer.accept(pair.first(), pair.second())))));
}
} }

View File

@ -0,0 +1,92 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import com.codahale.metrics.jdbi3.strategies.DefaultNameStrategy;
import io.dropwizard.Application;
import io.dropwizard.cli.EnvironmentCommand;
import io.dropwizard.jdbi3.JdbiFactory;
import io.dropwizard.setup.Environment;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.jdbi.v3.core.Jdbi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import java.util.concurrent.atomic.AtomicInteger;
public class MigrateAbusiveHostRulesCommand extends EnvironmentCommand<WhisperServerConfiguration> {
private static final Logger log = LoggerFactory.getLogger(MigrateAbusiveHostRulesCommand.class);
public MigrateAbusiveHostRulesCommand() {
super(new Application<>() {
@Override
public void run(WhisperServerConfiguration configuration, Environment environment) {
}
}, "migrate-abusive-host-rules", "Migrate abusive host rules from one Postgres to another");
}
@Override
public void configure(Subparser subparser) {
super.configure(subparser);
subparser.addArgument("-s", "--fetch-size")
.dest("fetchSize")
.type(Integer.class)
.required(false)
.setDefault(512)
.help("The number of rules to fetch from Postgres at once");
}
@Override
protected void run(final Environment environment, final Namespace namespace,
final WhisperServerConfiguration config) throws Exception {
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
new DynamicConfigurationManager<>(config.getAppConfig().getApplication(),
config.getAppConfig().getEnvironment(),
config.getAppConfig().getConfigurationName(),
DynamicConfiguration.class);
JdbiFactory jdbiFactory = new JdbiFactory(DefaultNameStrategy.CHECK_EMPTY);
Jdbi abuseJdbi = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb");
FaultTolerantDatabase abuseDatabase = new FaultTolerantDatabase("abuse_database", abuseJdbi,
config.getAbuseDatabaseConfiguration().getCircuitBreakerConfiguration());
Jdbi newAbuseJdbi = jdbiFactory.build(environment, config.getNewAbuseDatabaseConfiguration(), "abusedb2");
FaultTolerantDatabase newAbuseDatabase = new FaultTolerantDatabase("abuse_database2", newAbuseJdbi,
config.getNewAbuseDatabaseConfiguration().getCircuitBreakerConfiguration());
log.info("Beginning migration");
AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase, newAbuseDatabase,
dynamicConfigurationManager);
final int fetchSize = namespace.getInt("fetchSize");
final AtomicInteger rulesMigrated = new AtomicInteger(0);
abusiveHostRules.forEachInOldDatabase((rule, notes) -> {
abusiveHostRules.migrateAbusiveHostRule(rule, notes);
int migrated = rulesMigrated.incrementAndGet();
if (migrated % 1_000 == 0) {
log.info("Migrated {} rules", migrated);
}
}, fetchSize);
log.info("Migration complete ({} total rules)", rulesMigrated.get());
}
}

View File

@ -6,39 +6,63 @@
package org.whispersystems.textsecuregcm.tests.storage; package org.whispersystems.textsecuregcm.tests.storage;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.zonky.test.db.postgres.embedded.LiquibasePreparer; import io.zonky.test.db.postgres.embedded.LiquibasePreparer;
import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; import io.zonky.test.db.postgres.junit5.EmbeddedPostgresExtension;
import io.zonky.test.db.postgres.junit.PreparedDbRule; import io.zonky.test.db.postgres.junit5.PreparedDbExtension;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import org.jdbi.v3.core.Jdbi; import org.jdbi.v3.core.Jdbi;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule; import org.junit.jupiter.api.Test;
import org.junit.Test; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAbusiveHostRulesMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
public class AbusiveHostRulesTest { class AbusiveHostRulesTest {
@Rule @RegisterExtension
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("abusedb.xml")); PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(
LiquibasePreparer.forClasspathLocation("abusedb.xml"));
@RegisterExtension
PreparedDbExtension newDb = EmbeddedPostgresExtension.preparedDatabase(
LiquibasePreparer.forClasspathLocation("abusedb.xml"));
private AbusiveHostRules abusiveHostRules; private AbusiveHostRules abusiveHostRules;
@Before @BeforeEach
public void setup() { void setup() {
this.abusiveHostRules = new AbusiveHostRules(new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(db.getTestDatabase()), new CircuitBreakerConfiguration())); //noinspection unchecked
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(
DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getAbusiveHostRulesMigrationConfiguration()).thenReturn(
new DynamicAbusiveHostRulesMigrationConfiguration());
this.abusiveHostRules = new AbusiveHostRules(
new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(db.getTestDatabase()),
new CircuitBreakerConfiguration()),
new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(newDb.getTestDatabase()),
new CircuitBreakerConfiguration()),
dynamicConfigurationManager);
} }
@Test @Test
public void testBlockedHost() throws SQLException { void testBlockedHost() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); PreparedStatement statement = db.getTestDatabase().getConnection()
.prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.1"); statement.setString(1, "192.168.1.1");
statement.setInt(2, 1); statement.setInt(2, 1);
statement.execute(); statement.execute();
@ -51,8 +75,9 @@ public class AbusiveHostRulesTest {
} }
@Test @Test
public void testBlockedCidr() throws SQLException { void testBlockedCidr() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); PreparedStatement statement = db.getTestDatabase().getConnection()
.prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.0/24"); statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 1); statement.setInt(2, 1);
statement.execute(); statement.execute();
@ -65,8 +90,9 @@ public class AbusiveHostRulesTest {
} }
@Test @Test
public void testUnblocked() throws SQLException { void testUnblocked() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); PreparedStatement statement = db.getTestDatabase().getConnection()
.prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.0/24"); statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 1); statement.setInt(2, 1);
statement.execute(); statement.execute();
@ -76,8 +102,9 @@ public class AbusiveHostRulesTest {
} }
@Test @Test
public void testRestricted() throws SQLException { void testRestricted() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)"); PreparedStatement statement = db.getTestDatabase().getConnection()
.prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)");
statement.setString(1, "192.168.1.0/24"); statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 0); statement.setInt(2, 0);
statement.setString(3, "+1,+49"); statement.setString(3, "+1,+49");
@ -90,10 +117,11 @@ public class AbusiveHostRulesTest {
} }
@Test @Test
public void testInsertBlocked() throws Exception { void testInsertBlocked() throws Exception {
abusiveHostRules.setBlockedHost("172.17.0.1", "Testing one two"); abusiveHostRules.setBlockedHost("172.17.0.1", "Testing one two");
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); PreparedStatement statement = db.getTestDatabase().getConnection()
.prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet");
statement.setString(1, "172.17.0.1"); statement.setString(1, "172.17.0.1");
ResultSet resultSet = statement.executeQuery(); ResultSet resultSet = statement.executeQuery();
@ -106,8 +134,8 @@ public class AbusiveHostRulesTest {
abusiveHostRules.setBlockedHost("172.17.0.1", "Different notes"); abusiveHostRules.setBlockedHost("172.17.0.1", "Different notes");
statement = db.getTestDatabase().getConnection()
statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); .prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet");
statement.setString(1, "172.17.0.1"); statement.setString(1, "172.17.0.1");
resultSet = statement.executeQuery(); resultSet = statement.executeQuery();
@ -119,4 +147,30 @@ public class AbusiveHostRulesTest {
assertThat(resultSet.getString("notes")).isEqualTo("Testing one two"); assertThat(resultSet.getString("notes")).isEqualTo("Testing one two");
} }
@Test
void testMigrate() throws Exception {
final int rules = 20;
for (int i = 1; i <= rules; i++) {
abusiveHostRules.setBlockedHost("172.17.0." + i, "Testing one two " + i);
}
PreparedStatement statement = newDb.getTestDatabase().getConnection()
.prepareStatement("SELECT * from abusive_host_rules");
assertThat(queryResultSize(statement.executeQuery())).isEqualTo(0);
abusiveHostRules.forEachInOldDatabase((rule, host) -> abusiveHostRules.migrateAbusiveHostRule(rule, host), 5);
assertThat(queryResultSize(statement.executeQuery())).isEqualTo(rules);
}
private int queryResultSize(ResultSet resultSet) throws SQLException {
int migrated = 0;
while (resultSet.next()) {
migrated++;
}
return migrated;
}
} }