From a70b057e1c12c4620c1319486191a14a6ce1b718 Mon Sep 17 00:00:00 2001 From: Chris Eager <79161849+eager-signal@users.noreply.github.com> Date: Wed, 8 Dec 2021 12:46:05 -0800 Subject: [PATCH] Add second (migration) database to `AbusiveHostRules` --- service/config/sample.yml | 6 + .../WhisperServerConfiguration.java | 9 + .../textsecuregcm/WhisperServerService.java | 42 ++-- ...busiveHostRulesMigrationConfiguration.java | 32 +++ .../dynamic/DynamicConfiguration.java | 8 + .../textsecuregcm/experiment/Experiment.java | 192 +++++++++--------- .../storage/AbusiveHostRules.java | 123 ++++++++--- .../MigrateAbusiveHostRulesCommand.java | 92 +++++++++ .../tests/storage/AbusiveHostRulesTest.java | 100 ++++++--- 9 files changed, 450 insertions(+), 154 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicAbusiveHostRulesMigrationConfiguration.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateAbusiveHostRulesCommand.java diff --git a/service/config/sample.yml b/service/config/sample.yml index 88c2d1060..2c7292df5 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -221,6 +221,12 @@ abuseDatabase: # Postgresql database configuration password: password url: jdbc:postgresql://example.com:5432/abusedb +newAbuseDatabase: # Postgresql database configuration + driverClass: org.postgresql.Driver + user: example + password: password + url: jdbc:postgresql://new.example.com:5432/abusedb + accountDatabaseCrawler: chunkSize: 10 # accounts per run chunkIntervalMs: 60000 # time per run diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index f60b92628..05b7c0a4f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -215,6 +215,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private DatabaseConfiguration abuseDatabase; + @Valid + @NotNull + @JsonProperty + private DatabaseConfiguration newAbuseDatabase; + @Valid @NotNull @JsonProperty @@ -456,6 +461,10 @@ public class WhisperServerConfiguration extends Configuration { return abuseDatabase; } + public DatabaseConfiguration getNewAbuseDatabaseConfiguration() { + return newAbuseDatabase; + } + public RateLimitsConfiguration getLimitsConfiguration() { return limits; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index b30958ff8..73f274981 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -214,6 +214,7 @@ import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; +import org.whispersystems.textsecuregcm.workers.MigrateAbusiveHostRulesCommand; import org.whispersystems.textsecuregcm.workers.ReserveUsernameCommand; import org.whispersystems.textsecuregcm.workers.ServerVersionCommand; import org.whispersystems.textsecuregcm.workers.SetCrawlerAccelerationTask; @@ -242,6 +243,7 @@ public class WhisperServerService extends Application("abusedb", "abusedb.xml") { @Override @@ -305,9 +307,14 @@ public class WhisperServerService extends Application dynamicConfigurationManager = + new DynamicConfigurationManager<>(config.getAppConfig().getApplication(), + config.getAppConfig().getEnvironment(), + config.getAppConfig().getConfigurationName(), + DynamicConfiguration.class); + Accounts accounts = new Accounts(accountsDynamoDbClient, config.getAccountsDynamoDbConfiguration().getTableName(), config.getAccountsDynamoDbConfiguration().getPhoneNumberTableName(), @@ -381,12 +394,19 @@ public class WhisperServerService extends Application dynamicConfigurationManager = - new DynamicConfigurationManager<>(config.getAppConfig().getApplication(), - config.getAppConfig().getEnvironment(), - config.getAppConfig().getConfigurationName(), - DynamicConfiguration.class); - dynamicConfigurationManager.start(); ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicAbusiveHostRulesMigrationConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicAbusiveHostRulesMigrationConfiguration.java new file mode 100644 index 000000000..4ece29013 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicAbusiveHostRulesMigrationConfiguration.java @@ -0,0 +1,32 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration.dynamic; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class DynamicAbusiveHostRulesMigrationConfiguration { + + @JsonProperty + private boolean newReadEnabled = false; + + @JsonProperty + private boolean newWriteEnabled = false; + + @JsonProperty + private boolean newPrimary = false; + + public boolean isNewReadEnabled() { + return newReadEnabled; + } + + public boolean isNewWriteEnabled() { + return newWriteEnabled; + } + + public boolean isNewPrimary() { + return newPrimary; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java index 329be58dc..a6c8f2fa8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java @@ -59,6 +59,10 @@ public class DynamicConfiguration { @Valid private DynamicProfileMigrationConfiguration profileMigration = new DynamicProfileMigrationConfiguration(); + @JsonProperty + @Valid + private DynamicAbusiveHostRulesMigrationConfiguration abusiveHostRulesMigration = new DynamicAbusiveHostRulesMigrationConfiguration(); + public Optional getExperimentEnrollmentConfiguration( final String experimentName) { return Optional.ofNullable(experiments.get(experimentName)); @@ -117,4 +121,8 @@ public class DynamicConfiguration { public DynamicProfileMigrationConfiguration getProfileMigrationConfiguration() { return profileMigration; } + + public DynamicAbusiveHostRulesMigrationConfiguration getAbusiveHostRulesMigrationConfiguration() { + return abusiveHostRulesMigration; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java index de31a36db..5bd1d38c5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java @@ -5,12 +5,11 @@ package org.whispersystems.textsecuregcm.experiment; +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -18,115 +17,120 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; - -import static com.codahale.metrics.MetricRegistry.name; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * An experiment compares the results of two operations and records metrics to assess how frequently they match. */ public class Experiment { - private final String name; + private final String name; - private final Timer matchTimer; - private final Timer errorTimer; + private final Timer matchTimer; + private final Timer errorTimer; - private final Timer bothPresentMismatchTimer; - private final Timer controlNullMismatchTimer; - private final Timer experimentNullMismatchTimer; + private final Timer bothPresentMismatchTimer; + private final Timer controlNullMismatchTimer; + private final Timer experimentNullMismatchTimer; - private static final String OUTCOME_TAG = "outcome"; - private static final String MATCH_OUTCOME = "match"; - private static final String MISMATCH_OUTCOME = "mismatch"; - private static final String ERROR_OUTCOME = "error"; + private static final String OUTCOME_TAG = "outcome"; + private static final String MATCH_OUTCOME = "match"; + private static final String MISMATCH_OUTCOME = "mismatch"; + private static final String ERROR_OUTCOME = "error"; - private static final String MISMATCH_TYPE_TAG = "mismatchType"; - private static final String BOTH_PRESENT_MISMATCH = "bothPresent"; - private static final String CONTROL_NULL_MISMATCH = "controlResultNull"; - private static final String EXPERIMENT_NULL_MISMATCH = "experimentResultNull"; + private static final String MISMATCH_TYPE_TAG = "mismatchType"; + private static final String BOTH_PRESENT_MISMATCH = "bothPresent"; + private static final String CONTROL_NULL_MISMATCH = "controlResultNull"; + private static final String EXPERIMENT_NULL_MISMATCH = "experimentResultNull"; - private static final Logger log = LoggerFactory.getLogger(Experiment.class); + private static final Logger log = LoggerFactory.getLogger(Experiment.class); - public Experiment(final String... names) { - this(name(Experiment.class, names), - Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MATCH_OUTCOME), - Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, ERROR_OUTCOME), - Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, BOTH_PRESENT_MISMATCH), - Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, CONTROL_NULL_MISMATCH), - Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, EXPERIMENT_NULL_MISMATCH)); + public Experiment(final String... names) { + this(name(Experiment.class, names), + Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MATCH_OUTCOME), + Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, ERROR_OUTCOME), + Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, + BOTH_PRESENT_MISMATCH), + Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, + CONTROL_NULL_MISMATCH), + Metrics.timer(name(Experiment.class, names), OUTCOME_TAG, MISMATCH_OUTCOME, MISMATCH_TYPE_TAG, + EXPERIMENT_NULL_MISMATCH)); + } + + @VisibleForTesting + Experiment(final String name, final Timer matchTimer, final Timer errorTimer, final Timer bothPresentMismatchTimer, + final Timer controlNullMismatchTimer, final Timer experimentNullMismatchTimer) { + this.name = name; + + this.matchTimer = matchTimer; + this.errorTimer = errorTimer; + + this.bothPresentMismatchTimer = bothPresentMismatchTimer; + this.controlNullMismatchTimer = controlNullMismatchTimer; + this.experimentNullMismatchTimer = experimentNullMismatchTimer; + } + + public void compareFutureResult(final T expected, final CompletionStage experimentStage) { + final long startNanos = System.nanoTime(); + + experimentStage.whenComplete((actual, cause) -> { + final long durationNanos = System.nanoTime() - startNanos; + + if (cause != null) { + recordError(cause, durationNanos); + } else { + recordResult(expected, actual, durationNanos); + } + }); + } + + public void compareSupplierResult(final T expected, final Supplier experimentSupplier) { + final long startNanos = System.nanoTime(); + + try { + final T result = experimentSupplier.get(); + + recordResult(expected, result, System.nanoTime() - startNanos); + } catch (final Exception e) { + recordError(e, System.nanoTime() - startNanos); } + } - @VisibleForTesting - Experiment(final String name, final Timer matchTimer, final Timer errorTimer, final Timer bothPresentMismatchTimer, final Timer controlNullMismatchTimer, final Timer experimentNullMismatchTimer) { - this.name = name; + public void compareSupplierResultAsync(final T expected, final Supplier experimentSupplier, + final Executor executor) { + final long startNanos = System.nanoTime(); - this.matchTimer = matchTimer; - this.errorTimer = errorTimer; - - this.bothPresentMismatchTimer = bothPresentMismatchTimer; - this.controlNullMismatchTimer = controlNullMismatchTimer; - this.experimentNullMismatchTimer = experimentNullMismatchTimer; + try { + compareFutureResult(expected, CompletableFuture.supplyAsync(experimentSupplier, executor)); + } catch (final Exception e) { + recordError(e, System.nanoTime() - startNanos); } + } - public void compareFutureResult(final T expected, final CompletionStage experimentStage) { - final long startNanos = System.nanoTime(); + private void recordError(final Throwable cause, final long durationNanos) { + log.warn("Experiment {} threw an exception.", name, cause); + errorTimer.record(durationNanos, TimeUnit.NANOSECONDS); + } - experimentStage.whenComplete((actual, cause) -> { - final long durationNanos = System.nanoTime() - startNanos; + @VisibleForTesting + void recordResult(final T expected, final T actual, final long durationNanos) { + if (expected instanceof Optional && actual instanceof Optional) { + recordResult(((Optional) expected).orElse(null), ((Optional) actual).orElse(null), durationNanos); + } else { + final Timer Timer; - if (cause != null) { - recordError(cause, durationNanos); - } else { - recordResult(expected, actual, durationNanos); - } - }); - } - - public void compareSupplierResult(final T expected, final Supplier experimentSupplier) { - final long startNanos = System.nanoTime(); - - try { - final T result = experimentSupplier.get(); - - recordResult(expected, result, System.nanoTime() - startNanos); - } catch (final Exception e) { - recordError(e, System.nanoTime() - startNanos); - } - } - - public void compareSupplierResultAsync(final T expected, final Supplier experimentSupplier, final Executor executor) { - final long startNanos = System.nanoTime(); - - try { - compareFutureResult(expected, CompletableFuture.supplyAsync(experimentSupplier, executor)); - } catch (final Exception e) { - recordError(e, System.nanoTime() - startNanos); - } - } - - private void recordError(final Throwable cause, final long durationNanos) { - log.warn("Experiment {} threw an exception.", name, cause); - errorTimer.record(durationNanos, TimeUnit.NANOSECONDS); - } - - @VisibleForTesting - void recordResult(final T expected, final T actual, final long durationNanos) { - if (expected instanceof Optional && actual instanceof Optional) { - recordResult(((Optional)expected).orElse(null), ((Optional)actual).orElse(null), durationNanos); - } else { - final Timer Timer; - - if (Objects.equals(expected, actual)) { - Timer = matchTimer; - } else if (expected == null) { - Timer = controlNullMismatchTimer; - } else if (actual == null) { - Timer = experimentNullMismatchTimer; - } else { - Timer = bothPresentMismatchTimer; - } - - Timer.record(durationNanos, TimeUnit.NANOSECONDS); - } + if (Objects.equals(expected, actual)) { + Timer = matchTimer; + } else if (expected == null) { + Timer = controlNullMismatchTimer; + } else if (actual == null) { + Timer = experimentNullMismatchTimer; + } else { + Timer = bothPresentMismatchTimer; + } + + Timer.record(durationNanos, TimeUnit.NANOSECONDS); } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbusiveHostRules.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbusiveHostRules.java index 13bf0c8a5..3f04dc59b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbusiveHostRules.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbusiveHostRules.java @@ -5,56 +5,133 @@ package org.whispersystems.textsecuregcm.storage; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import com.google.common.base.Suppliers; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.storage.mappers.AbusiveHostRuleRowMapper; import org.whispersystems.textsecuregcm.util.Constants; - -import java.util.List; - -import static com.codahale.metrics.MetricRegistry.name; +import org.whispersystems.textsecuregcm.util.Pair; public class AbusiveHostRules { - public static final String ID = "id"; - public static final String HOST = "host"; + private static final Logger logger = LoggerFactory.getLogger(AbusiveHostRules.class); + + public static final String ID = "id"; + public static final String HOST = "host"; public static final String BLOCKED = "blocked"; public static final String REGIONS = "regions"; - public static final String NOTES = "notes"; + public static final String NOTES = "notes"; private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private final Timer getTimer = metricRegistry.timer(name(AbusiveHostRules.class, "get")); - private final Timer insertTimer = metricRegistry.timer(name(AbusiveHostRules.class, "setBlockedHost")); + private final Timer getTimer = metricRegistry.timer(name(AbusiveHostRules.class, "get")); + private final Timer insertTimer = metricRegistry.timer(name(AbusiveHostRules.class, "setBlockedHost")); - private final FaultTolerantDatabase database; + private final FaultTolerantDatabase oldDatabase; + private final FaultTolerantDatabase newDatabase; - public AbusiveHostRules(FaultTolerantDatabase database) { - this.database = database; - this.database.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper()); + private final DynamicConfigurationManager dynamicConfigurationManager; + private final Experiment migrationExperiment = new Experiment("abusiveHostRulesMigration"); + + public AbusiveHostRules(FaultTolerantDatabase oldDatabase, FaultTolerantDatabase newDatabase, + DynamicConfigurationManager dynamicConfigurationManager) { + + this.oldDatabase = oldDatabase; + this.oldDatabase.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper()); + + this.newDatabase = newDatabase; + this.newDatabase.getDatabase().registerRowMapper(new AbusiveHostRuleRowMapper()); + + this.dynamicConfigurationManager = dynamicConfigurationManager; } public List getAbusiveHostRulesFor(String host) { - return database.with(jdbi -> jdbi.withHandle(handle -> { + final List oldDbRules = oldDatabase.with(jdbi -> jdbi.withHandle(handle -> { try (Timer.Context timer = getTimer.time()) { return handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST) - .bind("host", host) - .mapTo(AbusiveHostRule.class) - .list(); + .bind("host", host) + .mapTo(AbusiveHostRule.class) + .list(); } })); + + final Supplier> newDbRules = Suppliers.memoize( + () -> newDatabase.with(jdbi -> jdbi.withHandle( + handle -> handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST) + .bind("host", host) + .mapTo(AbusiveHostRule.class) + .list()))); + + if (dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration().isNewReadEnabled()) { + migrationExperiment.compareSupplierResult(oldDbRules, newDbRules); + } + + return dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration().isNewPrimary() + ? newDbRules.get() + : oldDbRules; } public void setBlockedHost(String host, String notes) { - database.use(jdbi -> jdbi.useHandle(handle -> { + oldDatabase.use(jdbi -> jdbi.useHandle(handle -> { try (Timer.Context timer = insertTimer.time()) { - handle.createUpdate("INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING") - .bind("host", host) - .bind("blocked", 1) - .bind("notes", notes) - .execute(); + handle.createUpdate( + "INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING") + .bind("host", host) + .bind("blocked", 1) + .bind("notes", notes) + .execute(); + } + })); + + if (dynamicConfigurationManager.getConfiguration().getAbusiveHostRulesMigrationConfiguration() + .isNewWriteEnabled()) { + try { + newDatabase.use(jdbi -> jdbi.useHandle(handle -> handle.createUpdate( + "INSERT INTO abusive_host_rules(host, blocked, notes) VALUES(:host::inet, :blocked, :notes) ON CONFLICT DO NOTHING") + .bind("host", host) + .bind("blocked", 1) + .bind("notes", notes) + .execute())); + } catch (final Exception e) { + logger.warn("Failed to insert rule in new database", e); + } + } + } + + public int migrateAbusiveHostRule(AbusiveHostRule rule, String notes) { + return newDatabase.with(jdbi -> jdbi.withHandle(handle -> { + try (Timer.Context timer = insertTimer.time()) { + return handle.createUpdate( + "INSERT INTO abusive_host_rules(host, blocked, notes, regions) VALUES(:host::inet, :blocked, :notes, :regions) ON CONFLICT DO NOTHING") + .bind("host", rule.getHost()) + .bind("blocked", rule.isBlocked() ? 1 : 0) + .bind("notes", notes) + .bind("regions", String.join(",", rule.getRegions())) + .execute(); } })); } + public void forEachInOldDatabase(final BiConsumer consumer, final int fetchSize) { + final AbusiveHostRuleRowMapper rowMapper = new AbusiveHostRuleRowMapper(); + + oldDatabase.use(jdbi -> jdbi.useHandle(handle -> handle.useTransaction(transactionHandle -> + transactionHandle.createQuery("SELECT * FROM abusive_host_rules") + .setFetchSize(fetchSize) + .map((resultSet, ctx) -> { + AbusiveHostRule rule = rowMapper.map(resultSet, ctx); + String notes = resultSet.getString(NOTES); + return new Pair<>(rule, notes); + }) + .forEach(pair -> consumer.accept(pair.first(), pair.second()))))); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateAbusiveHostRulesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateAbusiveHostRulesCommand.java new file mode 100644 index 000000000..c819e1484 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateAbusiveHostRulesCommand.java @@ -0,0 +1,92 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import com.codahale.metrics.jdbi3.strategies.DefaultNameStrategy; +import io.dropwizard.Application; +import io.dropwizard.cli.EnvironmentCommand; +import io.dropwizard.jdbi3.JdbiFactory; +import io.dropwizard.setup.Environment; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import org.jdbi.v3.core.Jdbi; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; +import java.util.concurrent.atomic.AtomicInteger; + +public class MigrateAbusiveHostRulesCommand extends EnvironmentCommand { + + private static final Logger log = LoggerFactory.getLogger(MigrateAbusiveHostRulesCommand.class); + + public MigrateAbusiveHostRulesCommand() { + super(new Application<>() { + @Override + public void run(WhisperServerConfiguration configuration, Environment environment) { + } + }, "migrate-abusive-host-rules", "Migrate abusive host rules from one Postgres to another"); + } + + @Override + public void configure(Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("-s", "--fetch-size") + .dest("fetchSize") + .type(Integer.class) + .required(false) + .setDefault(512) + .help("The number of rules to fetch from Postgres at once"); + } + + @Override + protected void run(final Environment environment, final Namespace namespace, + final WhisperServerConfiguration config) throws Exception { + + DynamicConfigurationManager dynamicConfigurationManager = + new DynamicConfigurationManager<>(config.getAppConfig().getApplication(), + config.getAppConfig().getEnvironment(), + config.getAppConfig().getConfigurationName(), + DynamicConfiguration.class); + + JdbiFactory jdbiFactory = new JdbiFactory(DefaultNameStrategy.CHECK_EMPTY); + Jdbi abuseJdbi = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb"); + + FaultTolerantDatabase abuseDatabase = new FaultTolerantDatabase("abuse_database", abuseJdbi, + config.getAbuseDatabaseConfiguration().getCircuitBreakerConfiguration()); + + Jdbi newAbuseJdbi = jdbiFactory.build(environment, config.getNewAbuseDatabaseConfiguration(), "abusedb2"); + FaultTolerantDatabase newAbuseDatabase = new FaultTolerantDatabase("abuse_database2", newAbuseJdbi, + config.getNewAbuseDatabaseConfiguration().getCircuitBreakerConfiguration()); + + log.info("Beginning migration"); + + AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase, newAbuseDatabase, + dynamicConfigurationManager); + + final int fetchSize = namespace.getInt("fetchSize"); + + final AtomicInteger rulesMigrated = new AtomicInteger(0); + + abusiveHostRules.forEachInOldDatabase((rule, notes) -> { + + abusiveHostRules.migrateAbusiveHostRule(rule, notes); + + int migrated = rulesMigrated.incrementAndGet(); + + if (migrated % 1_000 == 0) { + log.info("Migrated {} rules", migrated); + } + }, fetchSize); + + log.info("Migration complete ({} total rules)", rulesMigrated.get()); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java index 2f231014f..63ebc528c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java @@ -6,39 +6,63 @@ package org.whispersystems.textsecuregcm.tests.storage; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import io.zonky.test.db.postgres.embedded.LiquibasePreparer; -import io.zonky.test.db.postgres.junit.EmbeddedPostgresRules; -import io.zonky.test.db.postgres.junit.PreparedDbRule; +import io.zonky.test.db.postgres.junit5.EmbeddedPostgresExtension; +import io.zonky.test.db.postgres.junit5.PreparedDbExtension; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.Arrays; import java.util.List; import org.jdbi.v3.core.Jdbi; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAbusiveHostRulesMigrationConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; -public class AbusiveHostRulesTest { +class AbusiveHostRulesTest { - @Rule - public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("abusedb.xml")); + @RegisterExtension + PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase( + LiquibasePreparer.forClasspathLocation("abusedb.xml")); + + @RegisterExtension + PreparedDbExtension newDb = EmbeddedPostgresExtension.preparedDatabase( + LiquibasePreparer.forClasspathLocation("abusedb.xml")); private AbusiveHostRules abusiveHostRules; - @Before - public void setup() { - this.abusiveHostRules = new AbusiveHostRules(new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(db.getTestDatabase()), new CircuitBreakerConfiguration())); + @BeforeEach + void setup() { + //noinspection unchecked + final DynamicConfigurationManager dynamicConfigurationManager = mock( + DynamicConfigurationManager.class); + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + when(dynamicConfiguration.getAbusiveHostRulesMigrationConfiguration()).thenReturn( + new DynamicAbusiveHostRulesMigrationConfiguration()); + + this.abusiveHostRules = new AbusiveHostRules( + new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(db.getTestDatabase()), + new CircuitBreakerConfiguration()), + new FaultTolerantDatabase("abusive_hosts-test", Jdbi.create(newDb.getTestDatabase()), + new CircuitBreakerConfiguration()), + dynamicConfigurationManager); } @Test - public void testBlockedHost() throws SQLException { - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + void testBlockedHost() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection() + .prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); statement.setString(1, "192.168.1.1"); statement.setInt(2, 1); statement.execute(); @@ -51,8 +75,9 @@ public class AbusiveHostRulesTest { } @Test - public void testBlockedCidr() throws SQLException { - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + void testBlockedCidr() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection() + .prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); statement.setString(1, "192.168.1.0/24"); statement.setInt(2, 1); statement.execute(); @@ -65,8 +90,9 @@ public class AbusiveHostRulesTest { } @Test - public void testUnblocked() throws SQLException { - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + void testUnblocked() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection() + .prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); statement.setString(1, "192.168.1.0/24"); statement.setInt(2, 1); statement.execute(); @@ -76,8 +102,9 @@ public class AbusiveHostRulesTest { } @Test - public void testRestricted() throws SQLException { - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)"); + void testRestricted() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection() + .prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)"); statement.setString(1, "192.168.1.0/24"); statement.setInt(2, 0); statement.setString(3, "+1,+49"); @@ -90,10 +117,11 @@ public class AbusiveHostRulesTest { } @Test - public void testInsertBlocked() throws Exception { + void testInsertBlocked() throws Exception { abusiveHostRules.setBlockedHost("172.17.0.1", "Testing one two"); - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); + PreparedStatement statement = db.getTestDatabase().getConnection() + .prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); statement.setString(1, "172.17.0.1"); ResultSet resultSet = statement.executeQuery(); @@ -106,8 +134,8 @@ public class AbusiveHostRulesTest { abusiveHostRules.setBlockedHost("172.17.0.1", "Different notes"); - - statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); + statement = db.getTestDatabase().getConnection() + .prepareStatement("SELECT * from abusive_host_rules WHERE host = ?::inet"); statement.setString(1, "172.17.0.1"); resultSet = statement.executeQuery(); @@ -119,4 +147,30 @@ public class AbusiveHostRulesTest { assertThat(resultSet.getString("notes")).isEqualTo("Testing one two"); } + @Test + void testMigrate() throws Exception { + final int rules = 20; + for (int i = 1; i <= rules; i++) { + abusiveHostRules.setBlockedHost("172.17.0." + i, "Testing one two " + i); + } + + PreparedStatement statement = newDb.getTestDatabase().getConnection() + .prepareStatement("SELECT * from abusive_host_rules"); + + assertThat(queryResultSize(statement.executeQuery())).isEqualTo(0); + + abusiveHostRules.forEachInOldDatabase((rule, host) -> abusiveHostRules.migrateAbusiveHostRule(rule, host), 5); + + assertThat(queryResultSize(statement.executeQuery())).isEqualTo(rules); + } + + private int queryResultSize(ResultSet resultSet) throws SQLException { + int migrated = 0; + while (resultSet.next()) { + migrated++; + } + + return migrated; + } + }