From 20f09e6c6e3421c5481cdefb3ef38d6314834b50 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Mon, 1 Apr 2019 20:09:02 -0700 Subject: [PATCH] Update to JDBIv3 --- pom.xml | 2 +- .../textsecuregcm/WhisperServerService.java | 29 +- .../storage/AbusiveHostRules.java | 48 ++-- .../textsecuregcm/storage/Accounts.java | 165 +++++------- .../storage/AccountsManager.java | 16 +- .../textsecuregcm/storage/Keys.java | 163 ++++------- .../textsecuregcm/storage/Messages.java | 209 +++++++-------- .../storage/MessagesManager.java | 4 +- .../storage/PendingAccounts.java | 65 ++--- .../storage/PendingAccountsManager.java | 9 +- .../textsecuregcm/storage/PendingDevices.java | 55 ++-- .../storage/PendingDevicesManager.java | 9 +- .../mappers/AbusiveHostRuleRowMapper.java | 28 ++ .../storage/mappers/AccountRowMapper.java | 29 ++ .../storage/mappers/KeyRecordRowMapper.java | 20 ++ .../OutgoingMessageEntityRowMapper.java | 37 +++ .../StoredVerificationCodeRowMapper.java | 17 ++ .../workers/DeleteUserCommand.java | 22 +- .../workers/TrimMessagesCommand.java | 50 ---- .../textsecuregcm/workers/VacuumCommand.java | 26 +- .../tests/storage/AbusiveHostRulesTest.java | 85 ++++++ .../tests/storage/AccountsManagerTest.java | 4 +- .../tests/storage/AccountsTest.java | 244 +++++++++++++++++ .../textsecuregcm/tests/storage/KeysTest.java | 44 +-- .../tests/storage/MessagesTest.java | 252 ++++++++++++++++++ .../tests/storage/PendingAccountsTest.java | 114 ++++++++ .../tests/storage/PendingDevicesTest.java | 100 +++++++ 27 files changed, 1275 insertions(+), 571 deletions(-) create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AbusiveHostRuleRowMapper.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/mappers/StoredVerificationCodeRowMapper.java delete mode 100644 src/main/java/org/whispersystems/textsecuregcm/workers/TrimMessagesCommand.java create mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java create mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java create mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java create mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/storage/PendingAccountsTest.java create mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/storage/PendingDevicesTest.java diff --git a/pom.xml b/pom.xml index ff3dd03aa..5ff997f4a 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ io.dropwizard - dropwizard-jdbi + dropwizard-jdbi3 ${dropwizard.version} diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index ef0e19a18..a091388d8 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -22,7 +22,7 @@ import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.DeserializationFeature; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.eclipse.jetty.servlets.CrossOriginFilter; -import org.skife.jdbi.v2.DBI; +import org.jdbi.v3.core.Jdbi; import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; @@ -38,8 +38,8 @@ import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.ProfileController; import org.whispersystems.textsecuregcm.controllers.ProvisioningController; -import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.controllers.TransparentDataController; +import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.liquibase.NameableMigrationsBundle; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; @@ -73,8 +73,6 @@ import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; -import org.whispersystems.textsecuregcm.workers.DirectoryCommand; -import org.whispersystems.textsecuregcm.workers.TrimMessagesCommand; import org.whispersystems.textsecuregcm.workers.VacuumCommand; import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.setup.WebSocketEnvironment; @@ -95,7 +93,7 @@ import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.auth.basic.BasicCredentialAuthFilter; import io.dropwizard.db.DataSourceFactory; import io.dropwizard.db.PooledDataSourceFactory; -import io.dropwizard.jdbi.DBIFactory; +import io.dropwizard.jdbi3.JdbiFactory; import io.dropwizard.setup.Bootstrap; import io.dropwizard.setup.Environment; @@ -108,7 +106,6 @@ public class WhisperServerService extends Application bootstrap) { bootstrap.addCommand(new VacuumCommand()); - bootstrap.addCommand(new TrimMessagesCommand()); bootstrap.addCommand(new DeleteUserCommand()); bootstrap.addCommand(new CertificateCommand()); bootstrap.addBundle(new NameableMigrationsBundle("accountdb", "accountsdb.xml") { @@ -147,17 +144,17 @@ public class WhisperServerService extends Application { - @Override - public AbusiveHostRule map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { - String regionsData = resultSet.getString(REGIONS); + public AbusiveHostRules(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new AbusiveHostRuleRowMapper()); + } - List regions; - - if (regionsData == null) regions = new LinkedList<>(); - else regions = Arrays.asList(regionsData.split(",")); - - - return new AbusiveHostRule(resultSet.getString(HOST), resultSet.getInt(BLOCKED) == 1, regions); - } + public List getAbusiveHostRulesFor(String host) { + return database.withHandle(handle -> handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST) + .bind("host", host) + .mapTo(AbusiveHostRule.class) + .list()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 19958b289..815f332e2 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -18,124 +18,87 @@ package org.whispersystems.textsecuregcm.storage; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import org.skife.jdbi.v2.SQLStatement; -import org.skife.jdbi.v2.StatementContext; -import org.skife.jdbi.v2.TransactionIsolationLevel; -import org.skife.jdbi.v2.sqlobject.Bind; -import org.skife.jdbi.v2.sqlobject.Binder; -import org.skife.jdbi.v2.sqlobject.BinderFactory; -import org.skife.jdbi.v2.sqlobject.BindingAnnotation; -import org.skife.jdbi.v2.sqlobject.GetGeneratedKeys; -import org.skife.jdbi.v2.sqlobject.SqlQuery; -import org.skife.jdbi.v2.sqlobject.SqlUpdate; -import org.skife.jdbi.v2.sqlobject.Transaction; -import org.skife.jdbi.v2.sqlobject.customizers.Mapper; -import org.skife.jdbi.v2.tweak.ResultSetMapper; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.transaction.TransactionIsolationLevel; +import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper; import org.whispersystems.textsecuregcm.util.SystemMapper; -import java.io.IOException; -import java.lang.annotation.Annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.Iterator; import java.util.List; +import java.util.Optional; -public abstract class Accounts { +public class Accounts { - private static final String ID = "id"; - private static final String NUMBER = "number"; - private static final String DATA = "data"; + public static final String ID = "id"; + public static final String NUMBER = "number"; + public static final String DATA = "data"; private static final ObjectMapper mapper = SystemMapper.getMapper(); - @SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))") - abstract void insertStep(@AccountBinder Account account); + private final Jdbi database; - @SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number") - abstract int removeAccount(@Bind("number") String number); + public Accounts(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new AccountRowMapper()); + } - @SqlUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number") - abstract void update(@AccountBinder Account account); - - @Mapper(AccountMapper.class) - @SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number") - public abstract Account get(@Bind("number") String number); - - @SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts") - public abstract long getCount(); - - @Mapper(AccountMapper.class) - @SqlQuery("SELECT * FROM accounts OFFSET :offset LIMIT :limit") - abstract List getAll(@Bind("offset") int offset, @Bind("limit") int length); - - @Mapper(AccountMapper.class) - @SqlQuery("SELECT * FROM accounts") - public abstract Iterator getAll(); - - @Mapper(AccountMapper.class) - @SqlQuery("SELECT * FROM accounts ORDER BY " + NUMBER + " LIMIT :limit") - public abstract List getAllFrom(@Bind("limit") int length); - - @Mapper(AccountMapper.class) - @SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " > :from ORDER BY " + NUMBER + " LIMIT :limit") - public abstract List getAllFrom(@Bind("from") String from, @Bind("limit") int length); - - @Transaction(TransactionIsolationLevel.SERIALIZABLE) public boolean create(Account account) { - int rows = removeAccount(account.getNumber()); - insertStep(account); - - return rows == 0; - } - - @SqlUpdate("VACUUM accounts") - public abstract void vacuum(); - - public static class AccountMapper implements ResultSetMapper { - @Override - public Account map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { + return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { try { - Account account = mapper.readValue(resultSet.getString(DATA), Account.class); - account.setNumber(resultSet.getString(NUMBER)); + int rows = handle.createUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number") + .bind("number", account.getNumber()) + .execute(); - return account; - } catch (IOException e) { - throw new SQLException(e); + handle.createUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))") + .bind("number", account.getNumber()) + .bind("data", mapper.writeValueAsString(account)) + .execute(); + + + return rows == 0; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e); } - } + }); } - @BindingAnnotation(AccountBinder.AccountBinderFactory.class) - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.PARAMETER}) - public @interface AccountBinder { - public static class AccountBinderFactory implements BinderFactory { - @Override - public Binder build(Annotation annotation) { - return new Binder() { - @Override - public void bind(SQLStatement sql, - AccountBinder accountBinder, - Account account) - { - try { - String serialized = mapper.writeValueAsString(account); - - sql.bind(NUMBER, account.getNumber()); - sql.bind(DATA, serialized); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException(e); - } - } - }; + public void update(Account account) { + database.useHandle(handle -> { + try { + handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number") + .bind("number", account.getNumber()) + .bind("data", mapper.writeValueAsString(account)) + .execute(); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e); } - } + }); + } + + public Optional get(String number) { + return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number") + .bind("number", number) + .mapTo(Account.class) + .findFirst()); + } + + + public List getAllFrom(String from, int length) { + return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " > :from ORDER BY " + NUMBER + " LIMIT :limit") + .bind("from", from) + .bind("limit", length) + .mapTo(Account.class) + .list()); + } + + public List getAllFrom(int length) { + return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts ORDER BY " + NUMBER + " LIMIT :limit") + .bind("limit", length) + .mapTo(Account.class) + .list()); + } + + public void vacuum() { + database.useHandle(handle -> handle.execute("VACUUM accounts")); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 3d6819fd4..908c0950c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import java.io.IOException; -import java.util.Iterator; -import java.util.List; import java.util.Optional; import static com.codahale.metrics.MetricRegistry.name; @@ -63,18 +61,6 @@ public class AccountsManager { this.mapper = SystemMapper.getMapper(); } - public long getCount() { - return accounts.getCount(); - } - - public List getAll(int offset, int length) { - return accounts.getAll(offset, length); - } - - public Iterator getAll() { - return accounts.getAll(); - } - public boolean create(Account account) { try (Timer.Context context = createTimer.time()) { boolean freshUser = databaseCreate(account); @@ -154,7 +140,7 @@ public class AccountsManager { } private Optional databaseGet(String number) { - return Optional.ofNullable(accounts.get(number)); + return accounts.get(number); } private boolean databaseCreate(Account account) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index e0424a0fa..d36c0dbb0 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -16,138 +16,75 @@ */ package org.whispersystems.textsecuregcm.storage; -import org.skife.jdbi.v2.SQLStatement; -import org.skife.jdbi.v2.StatementContext; -import org.skife.jdbi.v2.TransactionIsolationLevel; -import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException; -import org.skife.jdbi.v2.sqlobject.Bind; -import org.skife.jdbi.v2.sqlobject.Binder; -import org.skife.jdbi.v2.sqlobject.BinderFactory; -import org.skife.jdbi.v2.sqlobject.BindingAnnotation; -import org.skife.jdbi.v2.sqlobject.SqlBatch; -import org.skife.jdbi.v2.sqlobject.SqlQuery; -import org.skife.jdbi.v2.sqlobject.SqlUpdate; -import org.skife.jdbi.v2.sqlobject.Transaction; -import org.skife.jdbi.v2.sqlobject.customizers.Mapper; -import org.skife.jdbi.v2.tweak.ResultSetMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.statement.PreparedBatch; +import org.jdbi.v3.core.transaction.SerializableTransactionRunner; +import org.jdbi.v3.core.transaction.TransactionIsolationLevel; import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.storage.mappers.KeyRecordRowMapper; -import java.lang.annotation.Annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; -import java.sql.ResultSet; -import java.sql.SQLException; import java.util.List; -import java.util.concurrent.Callable; -import java.util.stream.Collectors; -public abstract class Keys { +public class Keys { - private static final Logger logger = LoggerFactory.getLogger(Keys.class); + private final Jdbi database; - @SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id") - abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId); - - @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)") - abstract void append(@KeyRecordBinder List preKeys); - - @SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") - @Mapper(KeyRecordMapper.class) - abstract List getInternal(@Bind("number") String number, @Bind("device_id") long deviceId); - - @SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") - @Mapper(KeyRecordMapper.class) - abstract List getInternal(@Bind("number") String number); - - @SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") - public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId); - - // Apparently transaction annotations don't work on the annotated query methods - @SuppressWarnings("WeakerAccess") - @Transaction(TransactionIsolationLevel.SERIALIZABLE) - List getInternalWithTransaction(String number) { - return getInternal(number); + public Keys(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new KeyRecordRowMapper()); + this.database.setTransactionHandler(new SerializableTransactionRunner()); + this.database.getConfig(SerializableTransactionRunner.Configuration.class).setMaxRetries(10); } - // Apparently transaction annotations don't work on the annotated query methods - @SuppressWarnings("WeakerAccess") - @Transaction(TransactionIsolationLevel.SERIALIZABLE) - List getInternalWithTransaction(String number, long deviceId) { - return getInternal(number, deviceId); - } + public void store(String number, long deviceId, List keys) { + database.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { + PreparedBatch preparedBatch = handle.prepareBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)"); - public List get(String number) { - return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number)); + for (PreKey key : keys) { + preparedBatch.bind("number", number) + .bind("device_id", deviceId) + .bind("key_id", key.getKeyId()) + .bind("public_key", key.getPublicKey()) + .add(); + } + + handle.createUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id") + .bind("number", number) + .bind("device_id", deviceId) + .execute(); + + preparedBatch.execute(); + }); } public List get(String number, long deviceId) { - return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number, deviceId)); + return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE, + handle -> handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") + .bind("number", number) + .bind("device_id", deviceId) + .mapTo(KeyRecord.class) + .list()); } - @Transaction(TransactionIsolationLevel.SERIALIZABLE) - public void store(String number, long deviceId, List keys) { - List records = keys.stream() - .map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey())) - .collect(Collectors.toList()); + public List get(String number) { + return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE, + handle -> handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") + .bind("number", number) + .mapTo(KeyRecord.class) + .list()); - remove(number, deviceId); - append(records); } - private List executeAndRetrySerializableAction(Callable> action) { - for (int i=0;i<20;i++) { - try { - return action.call(); - } catch (UnableToExecuteStatementException e) { - logger.info("Serializable conflict, retrying: " + e.getMessage()); - } catch (Exception e) { - throw new AssertionError(e); - } - } - - throw new UnableToExecuteStatementException("Retried statement too many times!"); + public int getCount(String number, long deviceId) { + return database.withHandle(handle -> handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") + .bind("number", number) + .bind("device_id", deviceId) + .mapTo(Integer.class) + .findOnly()); } - - @SqlUpdate("VACUUM keys") - public abstract void vacuum(); - - @BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class) - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.PARAMETER}) - public @interface KeyRecordBinder { - public static class PreKeyBinderFactory implements BinderFactory { - @Override - public Binder build(Annotation annotation) { - return new Binder() { - @Override - public void bind(SQLStatement sql, KeyRecordBinder keyRecordBinder, KeyRecord record) - { - sql.bind("id", record.getId()); - sql.bind("number", record.getNumber()); - sql.bind("device_id", record.getDeviceId()); - sql.bind("key_id", record.getKeyId()); - sql.bind("public_key", record.getPublicKey()); - } - }; - } - } - } - - - public static class KeyRecordMapper implements ResultSetMapper { - @Override - public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { - return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"), - resultSet.getLong("device_id"), resultSet.getLong("key_id"), - resultSet.getString("public_key")); - } + public void vacuum() { + database.useHandle(handle -> handle.execute("VACUUM keys")); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java index 4175b479a..b38efd01e 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java @@ -1,137 +1,106 @@ package org.whispersystems.textsecuregcm.storage; -import org.skife.jdbi.v2.SQLStatement; -import org.skife.jdbi.v2.StatementContext; -import org.skife.jdbi.v2.sqlobject.Bind; -import org.skife.jdbi.v2.sqlobject.Binder; -import org.skife.jdbi.v2.sqlobject.BinderFactory; -import org.skife.jdbi.v2.sqlobject.BindingAnnotation; -import org.skife.jdbi.v2.sqlobject.SqlQuery; -import org.skife.jdbi.v2.sqlobject.SqlUpdate; -import org.skife.jdbi.v2.sqlobject.customizers.Mapper; -import org.skife.jdbi.v2.tweak.ResultSetMapper; +import org.jdbi.v3.core.Jdbi; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.storage.mappers.OutgoingMessageEntityRowMapper; -import java.lang.annotation.Annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; -import java.sql.ResultSet; -import java.sql.SQLException; import java.util.List; +import java.util.Optional; import java.util.UUID; -public abstract class Messages { +public class Messages { static final int RESULT_SET_CHUNK_SIZE = 100; - private static final String ID = "id"; - private static final String GUID = "guid"; - private static final String TYPE = "type"; - private static final String RELAY = "relay"; - private static final String TIMESTAMP = "timestamp"; - private static final String SERVER_TIMESTAMP = "server_timestamp"; - private static final String SOURCE = "source"; - private static final String SOURCE_DEVICE = "source_device"; - private static final String DESTINATION = "destination"; - private static final String DESTINATION_DEVICE = "destination_device"; - private static final String MESSAGE = "message"; - private static final String CONTENT = "content"; + public static final String ID = "id"; + public static final String GUID = "guid"; + public static final String TYPE = "type"; + public static final String RELAY = "relay"; + public static final String TIMESTAMP = "timestamp"; + public static final String SERVER_TIMESTAMP = "server_timestamp"; + public static final String SOURCE = "source"; + public static final String SOURCE_DEVICE = "source_device"; + public static final String DESTINATION = "destination"; + public static final String DESTINATION_DEVICE = "destination_device"; + public static final String MESSAGE = "message"; + public static final String CONTENT = "content"; - @SqlUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + - "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_device, :destination, :destination_device, :message, :content)") - abstract void store(@Bind("guid") UUID guid, - @MessageBinder Envelope message, - @Bind("destination") String destination, - @Bind("destination_device") long destinationDevice); + private final Jdbi database; - @Mapper(MessageMapper.class) - @SqlQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC LIMIT " + RESULT_SET_CHUNK_SIZE) - abstract List load(@Bind("destination") String destination, - @Bind("destination_device") long destinationDevice); - - @Mapper(MessageMapper.class) - @SqlQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *") - abstract OutgoingMessageEntity remove(@Bind("destination") String destination, - @Bind("destination_device") long destinationDevice, - @Bind("source") String source, - @Bind("timestamp") long timestamp); - - @Mapper(MessageMapper.class) - @SqlQuery("DELETE FROM messages WHERE "+ ID + " IN (SELECT " + ID + " FROM MESSAGES WHERE " + GUID + " = :guid AND " + DESTINATION + " = :destination ORDER BY " + ID + " LIMIT 1) RETURNING *") - abstract OutgoingMessageEntity remove(@Bind("destination") String destination, @Bind("guid") UUID guid); - - @Mapper(MessageMapper.class) - @SqlUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination") - abstract void remove(@Bind("destination") String destination, @Bind("id") long id); - - @SqlUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination") - abstract void clear(@Bind("destination") String destination); - - @SqlUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device") - abstract void clear(@Bind("destination") String destination, @Bind("destination_device") long destinationDevice); - - @SqlUpdate("DELETE FROM messages WHERE " + TIMESTAMP + " < :timestamp") - public abstract void removeOld(@Bind("timestamp") long timestamp); - - @SqlUpdate("VACUUM messages") - public abstract void vacuum(); - - public static class MessageMapper implements ResultSetMapper { - @Override - public OutgoingMessageEntity map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { - - int type = resultSet.getInt(TYPE); - byte[] legacyMessage = resultSet.getBytes(MESSAGE); - String guid = resultSet.getString(GUID); - - if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) { - /// XXX - REMOVE AFTER 10/01/15 - legacyMessage = new byte[0]; - } - - return new OutgoingMessageEntity(resultSet.getLong(ID), - false, - guid == null ? null : UUID.fromString(guid), - type, - resultSet.getString(RELAY), - resultSet.getLong(TIMESTAMP), - resultSet.getString(SOURCE), - resultSet.getInt(SOURCE_DEVICE), - legacyMessage, - resultSet.getBytes(CONTENT), - resultSet.getLong(SERVER_TIMESTAMP)); - } + public Messages(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new OutgoingMessageEntityRowMapper()); } - @BindingAnnotation(MessageBinder.MessageBinderFactory.class) - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.PARAMETER}) - public @interface MessageBinder { - public static class MessageBinderFactory implements BinderFactory { - @Override - public Binder build(Annotation annotation) { - return new Binder() { - @Override - public void bind(SQLStatement sql, - MessageBinder accountBinder, - Envelope message) - { - sql.bind(TYPE, message.getType().getNumber()); - sql.bind(RELAY, message.getRelay()); - sql.bind(TIMESTAMP, message.getTimestamp()); - sql.bind(SERVER_TIMESTAMP, message.getServerTimestamp()); - sql.bind(SOURCE, message.hasSource() ? message.getSource() : null); - sql.bind(SOURCE_DEVICE, message.hasSourceDevice() ? message.getSourceDevice() : null); - sql.bind(MESSAGE, message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null); - sql.bind(CONTENT, message.hasContent() ? message.getContent().toByteArray() : null); - } - }; - } - } + public void store(UUID guid, Envelope message, String destination, long destinationDevice) { + database.useHandle(handle -> { + handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + + "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_device, :destination, :destination_device, :message, :content)") + .bind("guid", guid) + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .bind("type", message.getType().getNumber()) + .bind("relay", message.getRelay()) + .bind("timestamp", message.getTimestamp()) + .bind("server_timestamp", message.getServerTimestamp()) + .bind("source", message.hasSource() ? message.getSource() : null) + .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) + .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) + .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) + .execute(); + }); } + + public List load(String destination, long destinationDevice) { + return database.withHandle(handle -> handle.createQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC LIMIT " + RESULT_SET_CHUNK_SIZE) + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .mapTo(OutgoingMessageEntity.class) + .list()); + } + + public Optional remove(String destination, long destinationDevice, String source, long timestamp) { + return database.withHandle(handle -> handle.createQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *") + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .bind("source", source) + .bind("timestamp", timestamp) + .mapTo(OutgoingMessageEntity.class) + .findFirst()); + } + + public Optional remove(String destination, UUID guid) { + return database.withHandle(handle -> handle.createQuery("DELETE FROM messages WHERE "+ ID + " IN (SELECT " + ID + " FROM MESSAGES WHERE " + GUID + " = :guid AND " + DESTINATION + " = :destination ORDER BY " + ID + " LIMIT 1) RETURNING *") + .bind("destination", destination) + .bind("guid", guid) + .mapTo(OutgoingMessageEntity.class) + .findFirst()); + } + + public void remove(String destination, long id) { + database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination") + .bind("destination", destination) + .bind("id", id) + .execute()); + } + + public void clear(String destination) { + database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination") + .bind("destination", destination) + .execute()); + } + + public void clear(String destination, long destinationDevice) { + database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device") + .bind("destination", destination) + .bind("destination_device", destinationDevice) + .execute()); + } + + public void vacuum() { + database.useHandle(handle -> handle.execute("VACUUM messages")); + } + + } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index d34d80444..b87a254a3 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -64,7 +64,7 @@ public class MessagesManager { Optional removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp); if (!removed.isPresent()) { - removed = Optional.ofNullable(this.messages.remove(destination, destinationDevice, source, timestamp)); + removed = this.messages.remove(destination, destinationDevice, source, timestamp); cacheMissByNameMeter.mark(); } else { cacheHitByNameMeter.mark(); @@ -77,7 +77,7 @@ public class MessagesManager { Optional removed = this.messagesCache.remove(destination, deviceId, guid); if (!removed.isPresent()) { - removed = Optional.ofNullable(this.messages.remove(destination, guid)); + removed = this.messages.remove(destination, guid); cacheMissByGuidMeter.mark(); } else { cacheHitByGuidMeter.mark(); diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccounts.java index 2cd45b5b8..f9a92a838 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccounts.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccounts.java @@ -16,40 +16,45 @@ */ package org.whispersystems.textsecuregcm.storage; -import org.skife.jdbi.v2.StatementContext; -import org.skife.jdbi.v2.sqlobject.Bind; -import org.skife.jdbi.v2.sqlobject.SqlQuery; -import org.skife.jdbi.v2.sqlobject.SqlUpdate; -import org.skife.jdbi.v2.sqlobject.customizers.Mapper; -import org.skife.jdbi.v2.tweak.ResultSetMapper; +import org.jdbi.v3.core.Jdbi; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; +import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper; -import java.sql.ResultSet; -import java.sql.SQLException; +import java.util.Optional; -public interface PendingAccounts { +public class PendingAccounts { - @SqlUpdate("WITH upsert AS (UPDATE pending_accounts SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " + - "INSERT INTO pending_accounts (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)") - void insert(@Bind("number") String number, @Bind("verification_code") String verificationCode, @Bind("timestamp") long timestamp); + private final Jdbi database; - @Mapper(StoredVerificationCodeMapper.class) - @SqlQuery("SELECT verification_code, timestamp FROM pending_accounts WHERE number = :number") - StoredVerificationCode getCodeForNumber(@Bind("number") String number); - - @SqlUpdate("DELETE FROM pending_accounts WHERE number = :number") - void remove(@Bind("number") String number); - - @SqlUpdate("VACUUM pending_accounts") - public void vacuum(); - - public static class StoredVerificationCodeMapper implements ResultSetMapper { - @Override - public StoredVerificationCode map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { - return new StoredVerificationCode(resultSet.getString("verification_code"), - resultSet.getLong("timestamp")); - } + public PendingAccounts(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new StoredVerificationCodeRowMapper()); } + + public void insert(String number, String verificationCode, long timestamp) { + database.useHandle(handle -> handle.createUpdate("WITH upsert AS (UPDATE pending_accounts SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " + + "INSERT INTO pending_accounts (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)") + .bind("verification_code", verificationCode) + .bind("timestamp", timestamp) + .bind("number", number) + .execute()); + } + + public Optional getCodeForNumber(String number) { + return database.withHandle(handle -> handle.createQuery("SELECT verification_code, timestamp FROM pending_accounts WHERE number = :number") + .bind("number", number) + .mapTo(StoredVerificationCode.class) + .findFirst()); + } + + public void remove(String number) { + database.useHandle(handle -> handle.createUpdate("DELETE FROM pending_accounts WHERE number = :number") + .bind("number", number) + .execute()); + } + + public void vacuum() { + database.useHandle(handle -> handle.execute("VACUUM pending_accounts")); + } + } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java index 0963a0214..17dd31ea8 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2013 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -60,11 +60,8 @@ public class PendingAccountsManager { Optional code = memcacheGet(number); if (!code.isPresent()) { - code = Optional.ofNullable(pendingAccounts.getCodeForNumber(number)); - - if (code.isPresent()) { - memcacheSet(number, code.get()); - } + code = pendingAccounts.getCodeForNumber(number); + code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode)); } return code; diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevices.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevices.java index 74a1d459c..32c7fbf16 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevices.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevices.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2014 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -16,38 +16,41 @@ */ package org.whispersystems.textsecuregcm.storage; -import org.skife.jdbi.v2.StatementContext; -import org.skife.jdbi.v2.sqlobject.Bind; -import org.skife.jdbi.v2.sqlobject.SqlQuery; -import org.skife.jdbi.v2.sqlobject.SqlUpdate; -import org.skife.jdbi.v2.sqlobject.customizers.Mapper; -import org.skife.jdbi.v2.tweak.ResultSetMapper; +import org.jdbi.v3.core.Jdbi; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; +import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper; -import java.sql.ResultSet; -import java.sql.SQLException; +import java.util.Optional; -public interface PendingDevices { +public class PendingDevices { - @SqlUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " + - "INSERT INTO pending_devices (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)") - void insert(@Bind("number") String number, @Bind("verification_code") String verificationCode, @Bind("timestamp") long timestamp); + private final Jdbi database; - @Mapper(StoredVerificationCodeMapper.class) - @SqlQuery("SELECT verification_code, timestamp FROM pending_devices WHERE number = :number") - StoredVerificationCode getCodeForNumber(@Bind("number") String number); + public PendingDevices(Jdbi database) { + this.database = database; + this.database.registerRowMapper(new StoredVerificationCodeRowMapper()); + } - @SqlUpdate("DELETE FROM pending_devices WHERE number = :number") - void remove(@Bind("number") String number); + public void insert(String number, String verificationCode, long timestamp) { + database.useHandle(handle -> handle.createUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " + + "INSERT INTO pending_devices (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)") + .bind("number", number) + .bind("verification_code", verificationCode) + .bind("timestamp", timestamp) + .execute()); + } - public static class StoredVerificationCodeMapper implements ResultSetMapper { - @Override - public StoredVerificationCode map(int i, ResultSet resultSet, StatementContext statementContext) - throws SQLException - { - return new StoredVerificationCode(resultSet.getString("verification_code"), - resultSet.getLong("timestamp")); - } + public Optional getCodeForNumber(String number) { + return database.withHandle(handle -> handle.createQuery("SELECT verification_code, timestamp FROM pending_devices WHERE number = :number") + .bind("number", number) + .mapTo(StoredVerificationCode.class) + .findFirst()); + } + + public void remove(String number) { + database.useHandle(handle -> handle.createUpdate("DELETE FROM pending_devices WHERE number = :number") + .bind("number", number) + .execute()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java index edd594073..90207458c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2014 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -59,11 +59,8 @@ public class PendingDevicesManager { Optional code = memcacheGet(number); if (!code.isPresent()) { - code = Optional.ofNullable(pendingDevices.getCodeForNumber(number)); - - if (code.isPresent()) { - memcacheSet(number, code.get()); - } + code = pendingDevices.getCodeForNumber(number); + code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode)); } return code; diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AbusiveHostRuleRowMapper.java b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AbusiveHostRuleRowMapper.java new file mode 100644 index 000000000..d7451cc4d --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AbusiveHostRuleRowMapper.java @@ -0,0 +1,28 @@ +package org.whispersystems.textsecuregcm.storage.mappers; + +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; +import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + + +public class AbusiveHostRuleRowMapper implements RowMapper { + @Override + public AbusiveHostRule map(ResultSet resultSet, StatementContext ctx) throws SQLException { + String regionsData = resultSet.getString(AbusiveHostRules.REGIONS); + + List regions; + + if (regionsData == null) regions = new LinkedList<>(); + else regions = Arrays.asList(regionsData.split(",")); + + + return new AbusiveHostRule(resultSet.getString(AbusiveHostRules.HOST), resultSet.getInt(AbusiveHostRules.BLOCKED) == 1, regions); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java new file mode 100644 index 000000000..72b52add1 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java @@ -0,0 +1,29 @@ +package org.whispersystems.textsecuregcm.storage.mappers; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Accounts; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +import java.io.IOException; +import java.sql.ResultSet; +import java.sql.SQLException; + +public class AccountRowMapper implements RowMapper { + + private static ObjectMapper mapper = SystemMapper.getMapper(); + + @Override + public Account map(ResultSet resultSet, StatementContext ctx) throws SQLException { + try { + Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class); + account.setNumber(resultSet.getString(Accounts.NUMBER)); + + return account; + } catch (IOException e) { + throw new SQLException(e); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java new file mode 100644 index 000000000..dbac32522 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java @@ -0,0 +1,20 @@ +package org.whispersystems.textsecuregcm.storage.mappers; + +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.whispersystems.textsecuregcm.storage.KeyRecord; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public class KeyRecordRowMapper implements RowMapper { + + @Override + public KeyRecord map(ResultSet resultSet, StatementContext ctx) throws SQLException { + return new KeyRecord(resultSet.getLong("id"), + resultSet.getString("number"), + resultSet.getLong("device_id"), + resultSet.getLong("key_id"), + resultSet.getString("public_key")); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java new file mode 100644 index 000000000..2f238ea98 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java @@ -0,0 +1,37 @@ +package org.whispersystems.textsecuregcm.storage.mappers; + +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.storage.Messages; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.UUID; + +public class OutgoingMessageEntityRowMapper implements RowMapper { + @Override + public OutgoingMessageEntity map(ResultSet resultSet, StatementContext ctx) throws SQLException { + int type = resultSet.getInt(Messages.TYPE); + byte[] legacyMessage = resultSet.getBytes(Messages.MESSAGE); + String guid = resultSet.getString(Messages.GUID); + + if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) { + /// XXX - REMOVE AFTER 10/01/15 + legacyMessage = new byte[0]; + } + + return new OutgoingMessageEntity(resultSet.getLong(Messages.ID), + false, + guid == null ? null : UUID.fromString(guid), + type, + resultSet.getString(Messages.RELAY), + resultSet.getLong(Messages.TIMESTAMP), + resultSet.getString(Messages.SOURCE), + resultSet.getInt(Messages.SOURCE_DEVICE), + legacyMessage, + resultSet.getBytes(Messages.CONTENT), + resultSet.getLong(Messages.SERVER_TIMESTAMP)); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/StoredVerificationCodeRowMapper.java b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/StoredVerificationCodeRowMapper.java new file mode 100644 index 000000000..a7bad5304 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/StoredVerificationCodeRowMapper.java @@ -0,0 +1,17 @@ +package org.whispersystems.textsecuregcm.storage.mappers; + +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public class StoredVerificationCodeRowMapper implements RowMapper { + + @Override + public StoredVerificationCode map(ResultSet resultSet, StatementContext ctx) throws SQLException { + return new StoredVerificationCode(resultSet.getString("verification_code"), + resultSet.getLong("timestamp")); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java b/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java index 0d41326e8..f99ac0c59 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java +++ b/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java @@ -3,7 +3,7 @@ package org.whispersystems.textsecuregcm.workers; import com.fasterxml.jackson.databind.DeserializationFeature; import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Subparser; -import org.skife.jdbi.v2.DBI; +import org.jdbi.v3.core.Jdbi; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; @@ -23,17 +23,12 @@ import java.util.Optional; import io.dropwizard.Application; import io.dropwizard.cli.EnvironmentCommand; -import io.dropwizard.db.DataSourceFactory; -import io.dropwizard.jdbi.ImmutableListContainerFactory; -import io.dropwizard.jdbi.ImmutableSetContainerFactory; -import io.dropwizard.jdbi.OptionalContainerFactory; -import io.dropwizard.jdbi.args.OptionalArgumentFactory; +import io.dropwizard.jdbi3.JdbiFactory; import io.dropwizard.setup.Environment; -import redis.clients.jedis.JedisPool; public class DeleteUserCommand extends EnvironmentCommand { - private final Logger logger = LoggerFactory.getLogger(DirectoryCommand.class); + private final Logger logger = LoggerFactory.getLogger(DeleteUserCommand.class); public DeleteUserCommand() { super(new Application() { @@ -66,15 +61,10 @@ public class DeleteUserCommand extends EnvironmentCommand { - private final Logger logger = LoggerFactory.getLogger(TrimMessagesCommand.class); - - public TrimMessagesCommand() { - super("trim", "Trim Messages Database"); - } - - @Override - protected void run(Bootstrap bootstrap, - Namespace namespace, - WhisperServerConfiguration config) - throws Exception - { - DataSourceFactory messageDbConfig = config.getMessageStoreConfiguration(); - DBI messageDbi = new DBI(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword()); - - messageDbi.registerArgumentFactory(new OptionalArgumentFactory(messageDbConfig.getDriverClass())); - messageDbi.registerContainerFactory(new ImmutableListContainerFactory()); - messageDbi.registerContainerFactory(new ImmutableSetContainerFactory()); - messageDbi.registerContainerFactory(new OptionalContainerFactory()); - - Messages messages = messageDbi.onDemand(Messages.class); - long timestamp = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90); - - logger.info("Trimming old messages: " + timestamp + "..."); - messages.removeOld(timestamp); - - Thread.sleep(3000); - System.exit(0); - } -} diff --git a/src/main/java/org/whispersystems/textsecuregcm/workers/VacuumCommand.java b/src/main/java/org/whispersystems/textsecuregcm/workers/VacuumCommand.java index 01aaa1df5..a31b8860a 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/workers/VacuumCommand.java +++ b/src/main/java/org/whispersystems/textsecuregcm/workers/VacuumCommand.java @@ -1,7 +1,7 @@ package org.whispersystems.textsecuregcm.workers; import net.sourceforge.argparse4j.inf.Namespace; -import org.skife.jdbi.v2.DBI; +import org.jdbi.v3.core.Jdbi; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; @@ -12,10 +12,6 @@ import org.whispersystems.textsecuregcm.storage.PendingAccounts; import io.dropwizard.cli.ConfiguredCommand; import io.dropwizard.db.DataSourceFactory; -import io.dropwizard.jdbi.ImmutableListContainerFactory; -import io.dropwizard.jdbi.ImmutableSetContainerFactory; -import io.dropwizard.jdbi.OptionalContainerFactory; -import io.dropwizard.jdbi.args.OptionalArgumentFactory; import io.dropwizard.setup.Bootstrap; @@ -35,23 +31,15 @@ public class VacuumCommand extends ConfiguredCommand { DataSourceFactory dbConfig = config.getDataSourceFactory(); DataSourceFactory messageDbConfig = config.getMessageStoreConfiguration(); - DBI dbi = new DBI(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword() ); - DBI messageDbi = new DBI(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword()); - dbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass())); - dbi.registerContainerFactory(new ImmutableListContainerFactory()); - dbi.registerContainerFactory(new ImmutableSetContainerFactory()); - dbi.registerContainerFactory(new OptionalContainerFactory()); + Jdbi accountDatabase = Jdbi.create(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword()); + Jdbi messageDatabase = Jdbi.create(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword()); - messageDbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass())); - messageDbi.registerContainerFactory(new ImmutableListContainerFactory()); - messageDbi.registerContainerFactory(new ImmutableSetContainerFactory()); - messageDbi.registerContainerFactory(new OptionalContainerFactory()); - Accounts accounts = dbi.onDemand(Accounts.class ); - Keys keys = dbi.onDemand(Keys.class ); - PendingAccounts pendingAccounts = dbi.onDemand(PendingAccounts.class); - Messages messages = messageDbi.onDemand(Messages.class); + Accounts accounts = new Accounts(accountDatabase); + Keys keys = new Keys(accountDatabase); + PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase); + Messages messages = new Messages(messageDatabase); logger.info("Vacuuming accounts..."); accounts.vacuum(); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java new file mode 100644 index 000000000..a5164123c --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AbusiveHostRulesTest.java @@ -0,0 +1,85 @@ +package org.whispersystems.textsecuregcm.tests.storage; + +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit.EmbeddedPostgresRules; +import com.opentable.db.postgres.junit.PreparedDbRule; +import org.jdbi.v3.core.Jdbi; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.storage.AbusiveHostRule; +import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class AbusiveHostRulesTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("abusedb.xml")); + + private AbusiveHostRules abusiveHostRules; + + @Before + public void setup() { + this.abusiveHostRules = new AbusiveHostRules(Jdbi.create(db.getTestDatabase())); + } + + @Test + public void testBlockedHost() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + statement.setString(1, "192.168.1.1"); + statement.setInt(2, 1); + statement.execute(); + + List rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.1"); + assertThat(rules.size()).isEqualTo(1); + assertThat(rules.get(0).getRegions().isEmpty()).isTrue(); + assertThat(rules.get(0).getHost()).isEqualTo("192.168.1.1"); + assertThat(rules.get(0).isBlocked()).isTrue(); + } + + @Test + public void testBlockedCidr() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + statement.setString(1, "192.168.1.0/24"); + statement.setInt(2, 1); + statement.execute(); + + List rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.1"); + assertThat(rules.size()).isEqualTo(1); + assertThat(rules.get(0).getRegions().isEmpty()).isTrue(); + assertThat(rules.get(0).getHost()).isEqualTo("192.168.1.0/24"); + assertThat(rules.get(0).isBlocked()).isTrue(); + } + + @Test + public void testUnblocked() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)"); + statement.setString(1, "192.168.1.0/24"); + statement.setInt(2, 1); + statement.execute(); + + List rules = abusiveHostRules.getAbusiveHostRulesFor("172.17.1.1"); + assertThat(rules.isEmpty()).isTrue(); + } + + @Test + public void testRestricted() throws SQLException { + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)"); + statement.setString(1, "192.168.1.0/24"); + statement.setInt(2, 0); + statement.setString(3, "+1,+49"); + statement.execute(); + + List rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.100"); + assertThat(rules.size()).isEqualTo(1); + assertThat(rules.get(0).isBlocked()).isFalse(); + assertThat(rules.get(0).getRegions()).isEqualTo(Arrays.asList("+1", "+49")); + } + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java index 49942bbf0..1d68203f6 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java @@ -54,7 +54,7 @@ public class AccountsManagerTest { when(cacheClient.getReadResource()).thenReturn(jedis); when(cacheClient.getWriteResource()).thenReturn(jedis); when(jedis.get(eq("Account5+14152222222"))).thenReturn(null); - when(accounts.get(eq("+14152222222"))).thenReturn(account); + when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient); Optional retrieved = accountsManager.get("+14152222222"); @@ -82,7 +82,7 @@ public class AccountsManagerTest { when(cacheClient.getReadResource()).thenReturn(jedis); when(cacheClient.getWriteResource()).thenReturn(jedis); when(jedis.get(eq("Account5+14152222222"))).thenThrow(new JedisException("Connection lost!")); - when(accounts.get(eq("+14152222222"))).thenReturn(account); + when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient); Optional retrieved = accountsManager.get("+14152222222"); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java new file mode 100644 index 000000000..15c6e8014 --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java @@ -0,0 +1,244 @@ +package org.whispersystems.textsecuregcm.tests.storage; + +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit.EmbeddedPostgresRules; +import com.opentable.db.postgres.junit.PreparedDbRule; +import org.jdbi.v3.core.Jdbi; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Accounts; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper; + +import java.io.IOException; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.Set; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class AccountsTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); + + private Accounts accounts; + + @Before + public void setupAccountsDao() { + this.accounts = new Accounts(Jdbi.create(db.getTestDatabase())); + } + + @Test + public void testStore() throws SQLException, IOException { + Device device = generateDevice (1 ); + Account account = generateAccount("+14151112222", Collections.singleton(device)); + + accounts.create(account); + + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); + verifyStoredState(statement, "+14151112222", account); + } + + @Test + public void testStoreMulti() throws SQLException, IOException { + Set devices = new HashSet<>(); + devices.add(generateDevice(1)); + devices.add(generateDevice(2)); + + Account account = generateAccount("+14151112222", devices); + + accounts.create(account); + + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); + verifyStoredState(statement, "+14151112222", account); + } + + @Test + public void testRetrieve() { + Set devicesFirst = new HashSet<>(); + devicesFirst.add(generateDevice(1)); + devicesFirst.add(generateDevice(2)); + + Account accountFirst = generateAccount("+14151112222", devicesFirst); + + Set devicesSecond = new HashSet<>(); + devicesSecond.add(generateDevice(1)); + devicesSecond.add(generateDevice(2)); + + Account accountSecond = generateAccount("+14152221111", devicesSecond); + + accounts.create(accountFirst); + accounts.create(accountSecond); + + Optional retrievedFirst = accounts.get("+14151112222"); + Optional retrievedSecond = accounts.get("+14152221111"); + + assertThat(retrievedFirst.isPresent()).isTrue(); + assertThat(retrievedSecond.isPresent()).isTrue(); + + verifyStoredState("+14151112222", retrievedFirst.get(), accountFirst); + verifyStoredState("+14152221111", retrievedSecond.get(), accountSecond); + } + + @Test + public void testOverwrite() throws Exception { + Device device = generateDevice (1 ); + Account account = generateAccount("+14151112222", Collections.singleton(device)); + + accounts.create(account); + + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); + verifyStoredState(statement, "+14151112222", account); + + device = generateDevice(1); + account = generateAccount("+14151112222", Collections.singleton(device)); + + accounts.create(account); + verifyStoredState(statement, "+14151112222", account); + } + + @Test + public void testUpdate() { + Device device = generateDevice (1 ); + Account account = generateAccount("+14151112222", Collections.singleton(device)); + + accounts.create(account); + + device.setName("foobar"); + + accounts.update(account); + + Optional retrieved = accounts.get("+14151112222"); + + assertThat(retrieved.isPresent()).isTrue(); + verifyStoredState("+14151112222", retrieved.get(), account); + } + + @Test + public void testRetrieveFrom() { + List users = new ArrayList<>(); + + for (int i=1;i<=100;i++) { + Account account = generateAccount("+1" + String.format("%03d", i)); + users.add(account); + accounts.create(account); + } + + List retrieved = accounts.getAllFrom(10); + assertThat(retrieved.size()).isEqualTo(10); + + for (int i=0;i retrieved = accounts.get("+14151112222"); + assertThat(retrieved.isPresent()).isTrue(); + + verifyStoredState("+14151112222", retrieved.get(), account); + } + + @Test + public void testMissing() { + Device device = generateDevice (1 ); + Account account = generateAccount("+14151112222", Collections.singleton(device)); + + accounts.create(account); + + Optional retrieved = accounts.get("+11111111"); + assertThat(retrieved.isPresent()).isFalse(); + } + + + private Device generateDevice(long id) { + Random random = new Random(System.currentTimeMillis()); + SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt()); + return new Device(1, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(), null, "testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt(), random.nextBoolean()); + } + + private Account generateAccount(String number) { + Device device = generateDevice(1); + return generateAccount(number, Collections.singleton(device)); + } + + private Account generateAccount(String number, Set devices) { + byte[] unidentifiedAccessKey = new byte[16]; + Random random = new Random(System.currentTimeMillis()); + Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255)); + + return new Account(number, devices, unidentifiedAccessKey); + } + + private void verifyStoredState(PreparedStatement statement, String number, Account expecting) + throws SQLException, IOException + { + statement.setString(1, number); + + ResultSet resultSet = statement.executeQuery(); + + if (resultSet.next()) { + String data = resultSet.getString("data"); + assertThat(data).isNotEmpty(); + + Account result = new AccountRowMapper().map(resultSet, null); + verifyStoredState(number, result, expecting); + } else { + throw new AssertionError("No data"); + } + + assertThat(resultSet.next()).isFalse(); + } + + private void verifyStoredState(String number, Account result, Account expecting) { + assertThat(result.getNumber()).isEqualTo(number); + assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen()); + assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue(); + + for (Device expectingDevice : expecting.getDevices()) { + Device resultDevice = result.getDevice(expectingDevice.getId()).get(); + assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId()); + assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId()); + assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen()); + assertThat(resultDevice.getSignedPreKey().getPublicKey()).isEqualTo(expectingDevice.getSignedPreKey().getPublicKey()); + assertThat(resultDevice.getSignedPreKey().getKeyId()).isEqualTo(expectingDevice.getSignedPreKey().getKeyId()); + assertThat(resultDevice.getSignedPreKey().getSignature()).isEqualTo(expectingDevice.getSignedPreKey().getSignature()); + assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages()); + assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent()); + assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName()); + assertThat(resultDevice.getCreated()).isEqualTo(expectingDevice.getCreated()); + } + } + + + + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java index d29bcaa6c..c407c626f 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java @@ -3,9 +3,9 @@ package org.whispersystems.textsecuregcm.tests.storage; import com.opentable.db.postgres.embedded.LiquibasePreparer; import com.opentable.db.postgres.junit.EmbeddedPostgresRules; import com.opentable.db.postgres.junit.PreparedDbRule; +import org.jdbi.v3.core.Jdbi; import org.junit.Rule; import org.junit.Test; -import org.skife.jdbi.v2.DBI; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.storage.KeyRecord; import org.whispersystems.textsecuregcm.storage.Keys; @@ -27,8 +27,8 @@ public class KeysTest { @Test public void testPopulateKeys() throws SQLException { DataSource dataSource = db.getTestDatabase(); - DBI dbi = new DBI(dataSource); - Keys keys = dbi.onDemand(Keys.class); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List deviceOnePreKeys = new LinkedList<>(); List deviceTwoPreKeys = new LinkedList<>(); @@ -63,10 +63,10 @@ public class KeysTest { } @Test - public void testKeyCount() throws SQLException { + public void testKeyCount() { DataSource dataSource = db.getTestDatabase(); - DBI dbi = new DBI(dataSource); - Keys keys = dbi.onDemand(Keys.class); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List deviceOnePreKeys = new LinkedList<>(); @@ -83,8 +83,8 @@ public class KeysTest { @Test public void testGetForDevice() { DataSource dataSource = db.getTestDatabase(); - DBI dbi = new DBI(dataSource); - Keys keys = dbi.onDemand(Keys.class); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List deviceOnePreKeys = new LinkedList<>(); List deviceTwoPreKeys = new LinkedList<>(); @@ -144,8 +144,8 @@ public class KeysTest { @Test public void testGetForAllDevices() { DataSource dataSource = db.getTestDatabase(); - DBI dbi = new DBI(dataSource); - Keys keys = dbi.onDemand(Keys.class); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List deviceOnePreKeys = new LinkedList<>(); List deviceTwoPreKeys = new LinkedList<>(); @@ -220,8 +220,8 @@ public class KeysTest { @Test public void testGetForAllDevicesParallel() throws InterruptedException { DataSource dataSource = db.getTestDatabase(); - DBI dbi = new DBI(dataSource); - Keys keys = dbi.onDemand(Keys.class); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List deviceOnePreKeys = new LinkedList<>(); List deviceTwoPreKeys = new LinkedList<>(); @@ -239,7 +239,7 @@ public class KeysTest { List threads = new LinkedList<>(); - for (int i=0;i<50;i++) { + for (int i=0;i<20;i++) { Thread thread = new Thread(() -> { List results = keys.get("+14152222222"); assertThat(results.size()).isEqualTo(2); @@ -252,21 +252,31 @@ public class KeysTest { thread.join(); } - assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50); - assertThat(keys.getCount("+14152222222",2)).isEqualTo(50); + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(80); + assertThat(keys.getCount("+14152222222",2)).isEqualTo(80); } @Test public void testEmptyKeyGet() { - DBI dbi = new DBI(db.getTestDatabase()); - Keys keys = dbi.onDemand(Keys.class); + DataSource dataSource = db.getTestDatabase(); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); List records = keys.get("+14152222222"); assertThat(records.isEmpty()).isTrue(); } + @Test + public void testVacuum() { + DataSource dataSource = db.getTestDatabase(); + Jdbi jdbi = Jdbi.create(dataSource); + Keys keys = new Keys(jdbi); + + keys.vacuum(); + } + private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException { statement.setString(1, number); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java new file mode 100644 index 000000000..e75566500 --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java @@ -0,0 +1,252 @@ +package org.whispersystems.textsecuregcm.tests.storage; + +import com.google.protobuf.ByteString; +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit.EmbeddedPostgresRules; +import com.opentable.db.postgres.junit.PreparedDbRule; +import org.jdbi.v3.core.Jdbi; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.storage.Messages; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.UUID; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class MessagesTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("messagedb.xml")); + + private Messages messages; + + @Before + public void setupAccountsDao() { + this.messages = new Messages(Jdbi.create(db.getTestDatabase())); + } + + @Test + public void testStore() throws SQLException { + Envelope envelope = generateEnvelope(); + UUID guid = UUID.randomUUID(); + + messages.store(guid, envelope, "+14151112222", 1); + + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?"); + statement.setString(1, "+14151112222"); + + ResultSet resultSet = statement.executeQuery(); + assertThat(resultSet.next()).isTrue(); + + assertThat(resultSet.getString("guid")).isEqualTo(guid.toString()); + assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber()); + assertThat(resultSet.getString("relay")).isNullOrEmpty(); + assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp()); + assertThat(resultSet.getLong("server_timestamp")).isEqualTo(envelope.getServerTimestamp()); + assertThat(resultSet.getString("source")).isEqualTo(envelope.getSource()); + assertThat(resultSet.getLong("source_device")).isEqualTo(envelope.getSourceDevice()); + assertThat(resultSet.getBytes("message")).isEqualTo(envelope.getLegacyMessage().toByteArray()); + assertThat(resultSet.getBytes("content")).isEqualTo(envelope.getContent().toByteArray()); + assertThat(resultSet.getString("destination")).isEqualTo("+14151112222"); + assertThat(resultSet.getLong("destination_device")).isEqualTo(1); + + assertThat(resultSet.next()).isFalse(); + } + + @Test + public void testLoad() { + List inserted = new ArrayList<>(50); + + for (int i=0;i<50;i++) { + MessageToStore message = generateMessageToStore(); + inserted.add(message); + + messages.store(message.guid, message.envelope, "+14151112222", 1); + } + + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + + List retrieved = messages.load("+14151112222", 1); + + assertThat(retrieved.size()).isEqualTo(inserted.size()); + + for (int i=0;i inserted = insertRandom("+14151112222", 1); + List unrelated = insertRandom("+14151114444", 3); + MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", 1, toRemove.envelope.getSource(), toRemove.envelope.getTimestamp()); + + assertThat(removed.isPresent()).isTrue(); + verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); + + verifyInTact(inserted, "+14151112222", 1); + verifyInTact(unrelated, "+14151114444", 3); + } + + @Test + public void removeByDestinationGuid() { + List unrelated = insertRandom("+14151113333", 2); + List inserted = insertRandom("+14151112222", 1); + MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); + Optional removed = messages.remove("+14151112222", toRemove.guid); + + assertThat(removed.isPresent()).isTrue(); + verifyExpected(removed.get(), toRemove.envelope, toRemove.guid); + + verifyInTact(inserted, "+14151112222", 1); + verifyInTact(unrelated, "+14151113333", 2); + } + + @Test + public void removeByDestinationRowId() { + List unrelatedInserted = insertRandom("+14151111111", 1); + List inserted = insertRandom("+14151112222", 1); + + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + + List retrieved = messages.load("+14151112222", 1); + + int toRemoveIndex = new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1); + + inserted.remove(toRemoveIndex); + + messages.remove("+14151112222", retrieved.get(toRemoveIndex).getId()); + + verifyInTact(inserted, "+14151112222", 1); + verifyInTact(unrelatedInserted, "+14151111111", 1); + } + + @Test + public void testLoadEmpty() { + List inserted = insertRandom("+14151112222", 1); + List loaded = messages.load("+14159999999", 1); + assertThat(loaded.isEmpty()).isTrue(); + } + + @Test + public void testClearDestination() { + insertRandom("+14151112222", 1); + insertRandom("+14151112222", 2); + + List unrelated = insertRandom("+14151111111", 1); + + messages.clear("+14151112222"); + + assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue(); + + verifyInTact(unrelated, "+14151111111", 1); + } + + @Test + public void testClearDestinationDevice() { + insertRandom("+14151112222", 1); + List inserted = insertRandom("+14151112222", 2); + + List unrelated = insertRandom("+14151111111", 1); + + messages.clear("+14151112222", 1); + + assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue(); + + verifyInTact(inserted, "+14151112222", 2); + verifyInTact(unrelated, "+14151111111", 1); + } + + @Test + public void testVacuum() { + List inserted = insertRandom("+14151112222", 2); + messages.vacuum(); + verifyInTact(inserted, "+14151112222", 2); + } + + private List insertRandom(String destination, int destinationDevice) { + List inserted = new ArrayList<>(50); + + for (int i=0;i<50;i++) { + MessageToStore message = generateMessageToStore(); + inserted.add(message); + + messages.store(message.guid, message.envelope, destination, destinationDevice); + } + + return inserted; + } + + private void verifyInTact(List inserted, String destination, int destinationDevice) { + inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp())); + + List retrieved = messages.load(destination, destinationDevice); + + assertThat(retrieved.size()).isEqualTo(inserted.size()); + + for (int i=0;i verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4321"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222); + + Optional missingCode = pendingAccounts.getCodeForNumber("+11111111111"); + assertThat(missingCode.isPresent()).isFalse(); + } + + @Test + public void testOverwrite() throws Exception { + pendingAccounts.insert("+14151112222", "4321", 2222); + pendingAccounts.insert("+14151112222", "4444", 3333); + + Optional verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4444"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333); + } + + @Test + public void testVacuum() { + pendingAccounts.insert("+14151112222", "4321", 2222); + pendingAccounts.insert("+14151112222", "4444", 3333); + pendingAccounts.vacuum(); + + Optional verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4444"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333); + } + + @Test + public void testRemove() { + pendingAccounts.insert("+14151112222", "4321", 2222); + pendingAccounts.insert("+14151113333", "1212", 5555); + + Optional verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4321"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222); + + pendingAccounts.remove("+14151112222"); + + verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); + assertThat(verificationCode.isPresent()).isFalse(); + + verificationCode = pendingAccounts.getCodeForNumber("+14151113333"); + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("1212"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555); + } + + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PendingDevicesTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PendingDevicesTest.java new file mode 100644 index 000000000..6cc2ba5d4 --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PendingDevicesTest.java @@ -0,0 +1,100 @@ +package org.whispersystems.textsecuregcm.tests.storage; + +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit.EmbeddedPostgresRules; +import com.opentable.db.postgres.junit.PreparedDbRule; +import org.jdbi.v3.core.Jdbi; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; +import org.whispersystems.textsecuregcm.storage.PendingDevices; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Optional; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class PendingDevicesTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); + + private PendingDevices pendingDevices; + + @Before + public void setupAccountsDao() { + this.pendingDevices = new PendingDevices(Jdbi.create(db.getTestDatabase())); + } + + @Test + public void testStore() throws SQLException { + pendingDevices.insert("+14151112222", "1234", 1111); + + PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_devices WHERE number = ?"); + statement.setString(1, "+14151112222"); + + ResultSet resultSet = statement.executeQuery(); + + if (resultSet.next()) { + assertThat(resultSet.getString("verification_code")).isEqualTo("1234"); + assertThat(resultSet.getLong("timestamp")).isEqualTo(1111); + } else { + throw new AssertionError("no results"); + } + + assertThat(resultSet.next()).isFalse(); + } + + @Test + public void testRetrieve() throws Exception { + pendingDevices.insert("+14151112222", "4321", 2222); + pendingDevices.insert("+14151113333", "1212", 5555); + + Optional verificationCode = pendingDevices.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4321"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222); + + Optional missingCode = pendingDevices.getCodeForNumber("+11111111111"); + assertThat(missingCode.isPresent()).isFalse(); + } + + @Test + public void testOverwrite() throws Exception { + pendingDevices.insert("+14151112222", "4321", 2222); + pendingDevices.insert("+14151112222", "4444", 3333); + + Optional verificationCode = pendingDevices.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4444"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333); + } + + @Test + public void testRemove() { + pendingDevices.insert("+14151112222", "4321", 2222); + pendingDevices.insert("+14151113333", "1212", 5555); + + Optional verificationCode = pendingDevices.getCodeForNumber("+14151112222"); + + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("4321"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222); + + pendingDevices.remove("+14151112222"); + + verificationCode = pendingDevices.getCodeForNumber("+14151112222"); + assertThat(verificationCode.isPresent()).isFalse(); + + verificationCode = pendingDevices.getCodeForNumber("+14151113333"); + assertThat(verificationCode.isPresent()).isTrue(); + assertThat(verificationCode.get().getCode()).isEqualTo("1212"); + assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555); + } + +}