Update to JDBIv3

This commit is contained in:
Moxie Marlinspike 2019-04-01 20:09:02 -07:00
parent 944e1d9698
commit 20f09e6c6e
27 changed files with 1275 additions and 571 deletions

View File

@ -26,7 +26,7 @@
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-jdbi</artifactId>
<artifactId>dropwizard-jdbi3</artifactId>
<version>${dropwizard.version}</version>
</dependency>
<dependency>

View File

@ -22,7 +22,7 @@ import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.DeserializationFeature;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
@ -38,8 +38,8 @@ import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ProfileController;
import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController;
import org.whispersystems.textsecuregcm.controllers.TransparentDataController;
import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.liquibase.NameableMigrationsBundle;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
@ -73,8 +73,6 @@ import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.CertificateCommand;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
import org.whispersystems.textsecuregcm.workers.DirectoryCommand;
import org.whispersystems.textsecuregcm.workers.TrimMessagesCommand;
import org.whispersystems.textsecuregcm.workers.VacuumCommand;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -95,7 +93,7 @@ import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.db.PooledDataSourceFactory;
import io.dropwizard.jdbi.DBIFactory;
import io.dropwizard.jdbi3.JdbiFactory;
import io.dropwizard.setup.Bootstrap;
import io.dropwizard.setup.Environment;
@ -108,7 +106,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
@Override
public void initialize(Bootstrap<WhisperServerConfiguration> bootstrap) {
bootstrap.addCommand(new VacuumCommand());
bootstrap.addCommand(new TrimMessagesCommand());
bootstrap.addCommand(new DeleteUserCommand());
bootstrap.addCommand(new CertificateCommand());
bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("accountdb", "accountsdb.xml") {
@ -147,17 +144,17 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.getObjectMapper().setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
environment.getObjectMapper().setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
DBIFactory dbiFactory = new DBIFactory();
DBI database = dbiFactory.build(environment, config.getDataSourceFactory(), "accountdb");
DBI messagedb = dbiFactory.build(environment, config.getMessageStoreConfiguration(), "messagedb");
DBI abusedb = dbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb");
JdbiFactory jdbiFactory = new JdbiFactory();
Jdbi accountDatabase = jdbiFactory.build(environment, config.getDataSourceFactory(), "accountdb");
Jdbi messageDatabase = jdbiFactory.build(environment, config.getMessageStoreConfiguration(), "messagedb");
Jdbi abuseDatabase = jdbiFactory.build(environment, config.getAbuseDatabaseConfiguration(), "abusedb");
Accounts accounts = database.onDemand(Accounts.class );
PendingAccounts pendingAccounts = database.onDemand(PendingAccounts.class);
PendingDevices pendingDevices = database.onDemand(PendingDevices.class );
Keys keys = database.onDemand(Keys.class );
Messages messages = messagedb.onDemand(Messages.class);
AbusiveHostRules abusiveHostRules = abusedb.onDemand(AbusiveHostRules.class);
Accounts accounts = new Accounts(accountDatabase);
PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase);
PendingDevices pendingDevices = new PendingDevices(accountDatabase);
Keys keys = new Keys(accountDatabase);
Messages messages = new Messages(messageDatabase);
AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase);
RedisClientFactory cacheClientFactory = new RedisClientFactory("main_cache", config.getCacheConfiguration().getUrl(), config.getCacheConfiguration().getReplicaUrls(), config.getCacheConfiguration().getCircuitBreakerConfiguration());
RedisClientFactory directoryClientFactory = new RedisClientFactory("directory_cache", config.getDirectoryConfiguration().getRedisConfiguration().getUrl(), config.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls(), config.getDirectoryConfiguration().getRedisConfiguration().getCircuitBreakerConfiguration());

View File

@ -1,43 +1,29 @@
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.textsecuregcm.storage.mappers.AbusiveHostRuleRowMapper;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
public abstract class AbusiveHostRules {
public class AbusiveHostRules {
private static final String ID = "id";
private static final String HOST = "host";
private static final String BLOCKED = "blocked";
private static final String REGIONS = "regions";
public static final String ID = "id";
public static final String HOST = "host";
public static final String BLOCKED = "blocked";
public static final String REGIONS = "regions";
@Mapper(AbusiveHostRuleMapper.class)
@SqlQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST)
public abstract List<AbusiveHostRule> getAbusiveHostRulesFor(@Bind("host") String host);
private final Jdbi database;
public static class AbusiveHostRuleMapper implements ResultSetMapper<AbusiveHostRule> {
@Override
public AbusiveHostRule map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
String regionsData = resultSet.getString(REGIONS);
public AbusiveHostRules(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new AbusiveHostRuleRowMapper());
}
List<String> regions;
if (regionsData == null) regions = new LinkedList<>();
else regions = Arrays.asList(regionsData.split(","));
return new AbusiveHostRule(resultSet.getString(HOST), resultSet.getInt(BLOCKED) == 1, regions);
}
public List<AbusiveHostRule> getAbusiveHostRulesFor(String host) {
return database.withHandle(handle -> handle.createQuery("SELECT * FROM abusive_host_rules WHERE :host::inet <<= " + HOST)
.bind("host", host)
.mapTo(AbusiveHostRule.class)
.list());
}
}

View File

@ -18,124 +18,87 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.Binder;
import org.skife.jdbi.v2.sqlobject.BinderFactory;
import org.skife.jdbi.v2.sqlobject.BindingAnnotation;
import org.skife.jdbi.v2.sqlobject.GetGeneratedKeys;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
public abstract class Accounts {
public class Accounts {
private static final String ID = "id";
private static final String NUMBER = "number";
private static final String DATA = "data";
public static final String ID = "id";
public static final String NUMBER = "number";
public static final String DATA = "data";
private static final ObjectMapper mapper = SystemMapper.getMapper();
@SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
abstract void insertStep(@AccountBinder Account account);
private final Jdbi database;
@SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number")
abstract int removeAccount(@Bind("number") String number);
public Accounts(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new AccountRowMapper());
}
@SqlUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number")
abstract void update(@AccountBinder Account account);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
public abstract Account get(@Bind("number") String number);
@SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts")
public abstract long getCount();
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts OFFSET :offset LIMIT :limit")
abstract List<Account> getAll(@Bind("offset") int offset, @Bind("limit") int length);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts")
public abstract Iterator<Account> getAll();
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts ORDER BY " + NUMBER + " LIMIT :limit")
public abstract List<Account> getAllFrom(@Bind("limit") int length);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " > :from ORDER BY " + NUMBER + " LIMIT :limit")
public abstract List<Account> getAllFrom(@Bind("from") String from, @Bind("limit") int length);
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public boolean create(Account account) {
int rows = removeAccount(account.getNumber());
insertStep(account);
return rows == 0;
}
@SqlUpdate("VACUUM accounts")
public abstract void vacuum();
public static class AccountMapper implements ResultSetMapper<Account> {
@Override
public Account map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
try {
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setNumber(resultSet.getString(NUMBER));
int rows = handle.createUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number")
.bind("number", account.getNumber())
.execute();
return account;
} catch (IOException e) {
throw new SQLException(e);
handle.createUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
.bind("number", account.getNumber())
.bind("data", mapper.writeValueAsString(account))
.execute();
return rows == 0;
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
});
}
@BindingAnnotation(AccountBinder.AccountBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface AccountBinder {
public static class AccountBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
return new Binder<AccountBinder, Account>() {
@Override
public void bind(SQLStatement<?> sql,
AccountBinder accountBinder,
Account account)
{
try {
String serialized = mapper.writeValueAsString(account);
sql.bind(NUMBER, account.getNumber());
sql.bind(DATA, serialized);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
};
public void update(Account account) {
database.useHandle(handle -> {
try {
handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number")
.bind("number", account.getNumber())
.bind("data", mapper.writeValueAsString(account))
.execute();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
});
}
public Optional<Account> get(String number) {
return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
.bind("number", number)
.mapTo(Account.class)
.findFirst());
}
public List<Account> getAllFrom(String from, int length) {
return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " > :from ORDER BY " + NUMBER + " LIMIT :limit")
.bind("from", from)
.bind("limit", length)
.mapTo(Account.class)
.list());
}
public List<Account> getAllFrom(int length) {
return database.withHandle(handle -> handle.createQuery("SELECT * FROM accounts ORDER BY " + NUMBER + " LIMIT :limit")
.bind("limit", length)
.mapTo(Account.class)
.list());
}
public void vacuum() {
database.useHandle(handle -> handle.execute("VACUUM accounts"));
}
}

View File

@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
@ -63,18 +61,6 @@ public class AccountsManager {
this.mapper = SystemMapper.getMapper();
}
public long getCount() {
return accounts.getCount();
}
public List<Account> getAll(int offset, int length) {
return accounts.getAll(offset, length);
}
public Iterator<Account> getAll() {
return accounts.getAll();
}
public boolean create(Account account) {
try (Timer.Context context = createTimer.time()) {
boolean freshUser = databaseCreate(account);
@ -154,7 +140,7 @@ public class AccountsManager {
}
private Optional<Account> databaseGet(String number) {
return Optional.ofNullable(accounts.get(number));
return accounts.get(number);
}
private boolean databaseCreate(Account account) {

View File

@ -16,138 +16,75 @@
*/
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.Binder;
import org.skife.jdbi.v2.sqlobject.BinderFactory;
import org.skife.jdbi.v2.sqlobject.BindingAnnotation;
import org.skife.jdbi.v2.sqlobject.SqlBatch;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.core.statement.PreparedBatch;
import org.jdbi.v3.core.transaction.SerializableTransactionRunner;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.storage.mappers.KeyRecordRowMapper;
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
public abstract class Keys {
public class Keys {
private static final Logger logger = LoggerFactory.getLogger(Keys.class);
private final Jdbi database;
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)")
abstract void append(@KeyRecordBinder List<KeyRecord> preKeys);
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
@Mapper(KeyRecordMapper.class)
abstract List<KeyRecord> getInternal(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
@Mapper(KeyRecordMapper.class)
abstract List<KeyRecord> getInternal(@Bind("number") String number);
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
// Apparently transaction annotations don't work on the annotated query methods
@SuppressWarnings("WeakerAccess")
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
List<KeyRecord> getInternalWithTransaction(String number) {
return getInternal(number);
public Keys(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new KeyRecordRowMapper());
this.database.setTransactionHandler(new SerializableTransactionRunner());
this.database.getConfig(SerializableTransactionRunner.Configuration.class).setMaxRetries(10);
}
// Apparently transaction annotations don't work on the annotated query methods
@SuppressWarnings("WeakerAccess")
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
List<KeyRecord> getInternalWithTransaction(String number, long deviceId) {
return getInternal(number, deviceId);
}
public void store(String number, long deviceId, List<PreKey> keys) {
database.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
PreparedBatch preparedBatch = handle.prepareBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)");
public List<KeyRecord> get(String number) {
return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number));
for (PreKey key : keys) {
preparedBatch.bind("number", number)
.bind("device_id", deviceId)
.bind("key_id", key.getKeyId())
.bind("public_key", key.getPublicKey())
.add();
}
handle.createUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
.bind("number", number)
.bind("device_id", deviceId)
.execute();
preparedBatch.execute();
});
}
public List<KeyRecord> get(String number, long deviceId) {
return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number, deviceId));
return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE,
handle -> handle.createQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
.bind("number", number)
.bind("device_id", deviceId)
.mapTo(KeyRecord.class)
.list());
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List<PreKey> keys) {
List<KeyRecord> records = keys.stream()
.map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey()))
.collect(Collectors.toList());
public List<KeyRecord> get(String number) {
return database.inTransaction(TransactionIsolationLevel.SERIALIZABLE,
handle -> handle.createQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
.bind("number", number)
.mapTo(KeyRecord.class)
.list());
remove(number, deviceId);
append(records);
}
private List<KeyRecord> executeAndRetrySerializableAction(Callable<List<KeyRecord>> action) {
for (int i=0;i<20;i++) {
try {
return action.call();
} catch (UnableToExecuteStatementException e) {
logger.info("Serializable conflict, retrying: " + e.getMessage());
} catch (Exception e) {
throw new AssertionError(e);
}
}
throw new UnableToExecuteStatementException("Retried statement too many times!");
public int getCount(String number, long deviceId) {
return database.withHandle(handle -> handle.createQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
.bind("number", number)
.bind("device_id", deviceId)
.mapTo(Integer.class)
.findOnly());
}
@SqlUpdate("VACUUM keys")
public abstract void vacuum();
@BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface KeyRecordBinder {
public static class PreKeyBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
return new Binder<KeyRecordBinder, KeyRecord>() {
@Override
public void bind(SQLStatement<?> sql, KeyRecordBinder keyRecordBinder, KeyRecord record)
{
sql.bind("id", record.getId());
sql.bind("number", record.getNumber());
sql.bind("device_id", record.getDeviceId());
sql.bind("key_id", record.getKeyId());
sql.bind("public_key", record.getPublicKey());
}
};
}
}
}
public static class KeyRecordMapper implements ResultSetMapper<KeyRecord> {
@Override
public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
resultSet.getLong("device_id"), resultSet.getLong("key_id"),
resultSet.getString("public_key"));
}
public void vacuum() {
database.useHandle(handle -> handle.execute("VACUUM keys"));
}
}

View File

@ -1,137 +1,106 @@
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.Binder;
import org.skife.jdbi.v2.sqlobject.BinderFactory;
import org.skife.jdbi.v2.sqlobject.BindingAnnotation;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.storage.mappers.OutgoingMessageEntityRowMapper;
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public abstract class Messages {
public class Messages {
static final int RESULT_SET_CHUNK_SIZE = 100;
private static final String ID = "id";
private static final String GUID = "guid";
private static final String TYPE = "type";
private static final String RELAY = "relay";
private static final String TIMESTAMP = "timestamp";
private static final String SERVER_TIMESTAMP = "server_timestamp";
private static final String SOURCE = "source";
private static final String SOURCE_DEVICE = "source_device";
private static final String DESTINATION = "destination";
private static final String DESTINATION_DEVICE = "destination_device";
private static final String MESSAGE = "message";
private static final String CONTENT = "content";
public static final String ID = "id";
public static final String GUID = "guid";
public static final String TYPE = "type";
public static final String RELAY = "relay";
public static final String TIMESTAMP = "timestamp";
public static final String SERVER_TIMESTAMP = "server_timestamp";
public static final String SOURCE = "source";
public static final String SOURCE_DEVICE = "source_device";
public static final String DESTINATION = "destination";
public static final String DESTINATION_DEVICE = "destination_device";
public static final String MESSAGE = "message";
public static final String CONTENT = "content";
@SqlUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
"VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_device, :destination, :destination_device, :message, :content)")
abstract void store(@Bind("guid") UUID guid,
@MessageBinder Envelope message,
@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
private final Jdbi database;
@Mapper(MessageMapper.class)
@SqlQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC LIMIT " + RESULT_SET_CHUNK_SIZE)
abstract List<OutgoingMessageEntity> load(@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
@Mapper(MessageMapper.class)
@SqlQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *")
abstract OutgoingMessageEntity remove(@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice,
@Bind("source") String source,
@Bind("timestamp") long timestamp);
@Mapper(MessageMapper.class)
@SqlQuery("DELETE FROM messages WHERE "+ ID + " IN (SELECT " + ID + " FROM MESSAGES WHERE " + GUID + " = :guid AND " + DESTINATION + " = :destination ORDER BY " + ID + " LIMIT 1) RETURNING *")
abstract OutgoingMessageEntity remove(@Bind("destination") String destination, @Bind("guid") UUID guid);
@Mapper(MessageMapper.class)
@SqlUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination")
abstract void remove(@Bind("destination") String destination, @Bind("id") long id);
@SqlUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination")
abstract void clear(@Bind("destination") String destination);
@SqlUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device")
abstract void clear(@Bind("destination") String destination, @Bind("destination_device") long destinationDevice);
@SqlUpdate("DELETE FROM messages WHERE " + TIMESTAMP + " < :timestamp")
public abstract void removeOld(@Bind("timestamp") long timestamp);
@SqlUpdate("VACUUM messages")
public abstract void vacuum();
public static class MessageMapper implements ResultSetMapper<OutgoingMessageEntity> {
@Override
public OutgoingMessageEntity map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
int type = resultSet.getInt(TYPE);
byte[] legacyMessage = resultSet.getBytes(MESSAGE);
String guid = resultSet.getString(GUID);
if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) {
/// XXX - REMOVE AFTER 10/01/15
legacyMessage = new byte[0];
}
return new OutgoingMessageEntity(resultSet.getLong(ID),
false,
guid == null ? null : UUID.fromString(guid),
type,
resultSet.getString(RELAY),
resultSet.getLong(TIMESTAMP),
resultSet.getString(SOURCE),
resultSet.getInt(SOURCE_DEVICE),
legacyMessage,
resultSet.getBytes(CONTENT),
resultSet.getLong(SERVER_TIMESTAMP));
}
public Messages(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new OutgoingMessageEntityRowMapper());
}
@BindingAnnotation(MessageBinder.MessageBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface MessageBinder {
public static class MessageBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
return new Binder<MessageBinder, Envelope>() {
@Override
public void bind(SQLStatement<?> sql,
MessageBinder accountBinder,
Envelope message)
{
sql.bind(TYPE, message.getType().getNumber());
sql.bind(RELAY, message.getRelay());
sql.bind(TIMESTAMP, message.getTimestamp());
sql.bind(SERVER_TIMESTAMP, message.getServerTimestamp());
sql.bind(SOURCE, message.hasSource() ? message.getSource() : null);
sql.bind(SOURCE_DEVICE, message.hasSourceDevice() ? message.getSourceDevice() : null);
sql.bind(MESSAGE, message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null);
sql.bind(CONTENT, message.hasContent() ? message.getContent().toByteArray() : null);
}
};
}
}
public void store(UUID guid, Envelope message, String destination, long destinationDevice) {
database.useHandle(handle -> {
handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
"VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_device, :destination, :destination_device, :message, :content)")
.bind("guid", guid)
.bind("destination", destination)
.bind("destination_device", destinationDevice)
.bind("type", message.getType().getNumber())
.bind("relay", message.getRelay())
.bind("timestamp", message.getTimestamp())
.bind("server_timestamp", message.getServerTimestamp())
.bind("source", message.hasSource() ? message.getSource() : null)
.bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null)
.bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null)
.bind("content", message.hasContent() ? message.getContent().toByteArray() : null)
.execute();
});
}
public List<OutgoingMessageEntity> load(String destination, long destinationDevice) {
return database.withHandle(handle -> handle.createQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC LIMIT " + RESULT_SET_CHUNK_SIZE)
.bind("destination", destination)
.bind("destination_device", destinationDevice)
.mapTo(OutgoingMessageEntity.class)
.list());
}
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String source, long timestamp) {
return database.withHandle(handle -> handle.createQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *")
.bind("destination", destination)
.bind("destination_device", destinationDevice)
.bind("source", source)
.bind("timestamp", timestamp)
.mapTo(OutgoingMessageEntity.class)
.findFirst());
}
public Optional<OutgoingMessageEntity> remove(String destination, UUID guid) {
return database.withHandle(handle -> handle.createQuery("DELETE FROM messages WHERE "+ ID + " IN (SELECT " + ID + " FROM MESSAGES WHERE " + GUID + " = :guid AND " + DESTINATION + " = :destination ORDER BY " + ID + " LIMIT 1) RETURNING *")
.bind("destination", destination)
.bind("guid", guid)
.mapTo(OutgoingMessageEntity.class)
.findFirst());
}
public void remove(String destination, long id) {
database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination")
.bind("destination", destination)
.bind("id", id)
.execute());
}
public void clear(String destination) {
database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination")
.bind("destination", destination)
.execute());
}
public void clear(String destination, long destinationDevice) {
database.useHandle(handle -> handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device")
.bind("destination", destination)
.bind("destination_device", destinationDevice)
.execute());
}
public void vacuum() {
database.useHandle(handle -> handle.execute("VACUUM messages"));
}
}

View File

@ -64,7 +64,7 @@ public class MessagesManager {
Optional<OutgoingMessageEntity> removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp);
if (!removed.isPresent()) {
removed = Optional.ofNullable(this.messages.remove(destination, destinationDevice, source, timestamp));
removed = this.messages.remove(destination, destinationDevice, source, timestamp);
cacheMissByNameMeter.mark();
} else {
cacheHitByNameMeter.mark();
@ -77,7 +77,7 @@ public class MessagesManager {
Optional<OutgoingMessageEntity> removed = this.messagesCache.remove(destination, deviceId, guid);
if (!removed.isPresent()) {
removed = Optional.ofNullable(this.messages.remove(destination, guid));
removed = this.messages.remove(destination, guid);
cacheMissByGuidMeter.mark();
} else {
cacheHitByGuidMeter.mark();

View File

@ -16,40 +16,45 @@
*/
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;
public interface PendingAccounts {
public class PendingAccounts {
@SqlUpdate("WITH upsert AS (UPDATE pending_accounts SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " +
"INSERT INTO pending_accounts (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)")
void insert(@Bind("number") String number, @Bind("verification_code") String verificationCode, @Bind("timestamp") long timestamp);
private final Jdbi database;
@Mapper(StoredVerificationCodeMapper.class)
@SqlQuery("SELECT verification_code, timestamp FROM pending_accounts WHERE number = :number")
StoredVerificationCode getCodeForNumber(@Bind("number") String number);
@SqlUpdate("DELETE FROM pending_accounts WHERE number = :number")
void remove(@Bind("number") String number);
@SqlUpdate("VACUUM pending_accounts")
public void vacuum();
public static class StoredVerificationCodeMapper implements ResultSetMapper<StoredVerificationCode> {
@Override
public StoredVerificationCode map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new StoredVerificationCode(resultSet.getString("verification_code"),
resultSet.getLong("timestamp"));
}
public PendingAccounts(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new StoredVerificationCodeRowMapper());
}
public void insert(String number, String verificationCode, long timestamp) {
database.useHandle(handle -> handle.createUpdate("WITH upsert AS (UPDATE pending_accounts SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " +
"INSERT INTO pending_accounts (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)")
.bind("verification_code", verificationCode)
.bind("timestamp", timestamp)
.bind("number", number)
.execute());
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.withHandle(handle -> handle.createQuery("SELECT verification_code, timestamp FROM pending_accounts WHERE number = :number")
.bind("number", number)
.mapTo(StoredVerificationCode.class)
.findFirst());
}
public void remove(String number) {
database.useHandle(handle -> handle.createUpdate("DELETE FROM pending_accounts WHERE number = :number")
.bind("number", number)
.execute());
}
public void vacuum() {
database.useHandle(handle -> handle.execute("VACUUM pending_accounts"));
}
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -60,11 +60,8 @@ public class PendingAccountsManager {
Optional<StoredVerificationCode> code = memcacheGet(number);
if (!code.isPresent()) {
code = Optional.ofNullable(pendingAccounts.getCodeForNumber(number));
if (code.isPresent()) {
memcacheSet(number, code.get());
}
code = pendingAccounts.getCodeForNumber(number);
code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode));
}
return code;

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -16,38 +16,41 @@
*/
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;
public interface PendingDevices {
public class PendingDevices {
@SqlUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " +
"INSERT INTO pending_devices (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)")
void insert(@Bind("number") String number, @Bind("verification_code") String verificationCode, @Bind("timestamp") long timestamp);
private final Jdbi database;
@Mapper(StoredVerificationCodeMapper.class)
@SqlQuery("SELECT verification_code, timestamp FROM pending_devices WHERE number = :number")
StoredVerificationCode getCodeForNumber(@Bind("number") String number);
public PendingDevices(Jdbi database) {
this.database = database;
this.database.registerRowMapper(new StoredVerificationCodeRowMapper());
}
@SqlUpdate("DELETE FROM pending_devices WHERE number = :number")
void remove(@Bind("number") String number);
public void insert(String number, String verificationCode, long timestamp) {
database.useHandle(handle -> handle.createUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " +
"INSERT INTO pending_devices (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)")
.bind("number", number)
.bind("verification_code", verificationCode)
.bind("timestamp", timestamp)
.execute());
}
public static class StoredVerificationCodeMapper implements ResultSetMapper<StoredVerificationCode> {
@Override
public StoredVerificationCode map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new StoredVerificationCode(resultSet.getString("verification_code"),
resultSet.getLong("timestamp"));
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.withHandle(handle -> handle.createQuery("SELECT verification_code, timestamp FROM pending_devices WHERE number = :number")
.bind("number", number)
.mapTo(StoredVerificationCode.class)
.findFirst());
}
public void remove(String number) {
database.useHandle(handle -> handle.createUpdate("DELETE FROM pending_devices WHERE number = :number")
.bind("number", number)
.execute());
}
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -59,11 +59,8 @@ public class PendingDevicesManager {
Optional<StoredVerificationCode> code = memcacheGet(number);
if (!code.isPresent()) {
code = Optional.ofNullable(pendingDevices.getCodeForNumber(number));
if (code.isPresent()) {
memcacheSet(number, code.get());
}
code = pendingDevices.getCodeForNumber(number);
code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode));
}
return code;

View File

@ -0,0 +1,28 @@
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
public class AbusiveHostRuleRowMapper implements RowMapper<AbusiveHostRule> {
@Override
public AbusiveHostRule map(ResultSet resultSet, StatementContext ctx) throws SQLException {
String regionsData = resultSet.getString(AbusiveHostRules.REGIONS);
List<String> regions;
if (regionsData == null) regions = new LinkedList<>();
else regions = Arrays.asList(regionsData.split(","));
return new AbusiveHostRule(resultSet.getString(AbusiveHostRules.HOST), resultSet.getInt(AbusiveHostRules.BLOCKED) == 1, regions);
}
}

View File

@ -0,0 +1,29 @@
package org.whispersystems.textsecuregcm.storage.mappers;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import java.sql.ResultSet;
import java.sql.SQLException;
public class AccountRowMapper implements RowMapper<Account> {
private static ObjectMapper mapper = SystemMapper.getMapper();
@Override
public Account map(ResultSet resultSet, StatementContext ctx) throws SQLException {
try {
Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class);
account.setNumber(resultSet.getString(Accounts.NUMBER));
return account;
} catch (IOException e) {
throw new SQLException(e);
}
}
}

View File

@ -0,0 +1,20 @@
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import java.sql.ResultSet;
import java.sql.SQLException;
public class KeyRecordRowMapper implements RowMapper<KeyRecord> {
@Override
public KeyRecord map(ResultSet resultSet, StatementContext ctx) throws SQLException {
return new KeyRecord(resultSet.getLong("id"),
resultSet.getString("number"),
resultSet.getLong("device_id"),
resultSet.getLong("key_id"),
resultSet.getString("public_key"));
}
}

View File

@ -0,0 +1,37 @@
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.storage.Messages;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.UUID;
public class OutgoingMessageEntityRowMapper implements RowMapper<OutgoingMessageEntity> {
@Override
public OutgoingMessageEntity map(ResultSet resultSet, StatementContext ctx) throws SQLException {
int type = resultSet.getInt(Messages.TYPE);
byte[] legacyMessage = resultSet.getBytes(Messages.MESSAGE);
String guid = resultSet.getString(Messages.GUID);
if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) {
/// XXX - REMOVE AFTER 10/01/15
legacyMessage = new byte[0];
}
return new OutgoingMessageEntity(resultSet.getLong(Messages.ID),
false,
guid == null ? null : UUID.fromString(guid),
type,
resultSet.getString(Messages.RELAY),
resultSet.getLong(Messages.TIMESTAMP),
resultSet.getString(Messages.SOURCE),
resultSet.getInt(Messages.SOURCE_DEVICE),
legacyMessage,
resultSet.getBytes(Messages.CONTENT),
resultSet.getLong(Messages.SERVER_TIMESTAMP));
}
}

View File

@ -0,0 +1,17 @@
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import java.sql.ResultSet;
import java.sql.SQLException;
public class StoredVerificationCodeRowMapper implements RowMapper<StoredVerificationCode> {
@Override
public StoredVerificationCode map(ResultSet resultSet, StatementContext ctx) throws SQLException {
return new StoredVerificationCode(resultSet.getString("verification_code"),
resultSet.getLong("timestamp"));
}
}

View File

@ -3,7 +3,7 @@ package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.databind.DeserializationFeature;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.skife.jdbi.v2.DBI;
import org.jdbi.v3.core.Jdbi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
@ -23,17 +23,12 @@ import java.util.Optional;
import io.dropwizard.Application;
import io.dropwizard.cli.EnvironmentCommand;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.jdbi.ImmutableListContainerFactory;
import io.dropwizard.jdbi.ImmutableSetContainerFactory;
import io.dropwizard.jdbi.OptionalContainerFactory;
import io.dropwizard.jdbi.args.OptionalArgumentFactory;
import io.dropwizard.jdbi3.JdbiFactory;
import io.dropwizard.setup.Environment;
import redis.clients.jedis.JedisPool;
public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfiguration> {
private final Logger logger = LoggerFactory.getLogger(DirectoryCommand.class);
private final Logger logger = LoggerFactory.getLogger(DeleteUserCommand.class);
public DeleteUserCommand() {
super(new Application<WhisperServerConfiguration>() {
@ -66,15 +61,10 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
environment.getObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
DataSourceFactory dbConfig = configuration.getDataSourceFactory();
DBI dbi = new DBI(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword());
JdbiFactory jdbiFactory = new JdbiFactory();
Jdbi accountDatabase = jdbiFactory.build(environment, configuration.getDataSourceFactory(), "accountdb");
dbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass()));
dbi.registerContainerFactory(new ImmutableListContainerFactory());
dbi.registerContainerFactory(new ImmutableSetContainerFactory());
dbi.registerContainerFactory(new OptionalContainerFactory());
Accounts accounts = dbi.onDemand(Accounts.class);
Accounts accounts = new Accounts(accountDatabase);
ReplicatedJedisPool cacheClient = new RedisClientFactory("main_cache_delete_command", configuration.getCacheConfiguration().getUrl(), configuration.getCacheConfiguration().getReplicaUrls(), configuration.getCacheConfiguration().getCircuitBreakerConfiguration()).getRedisClientPool();
ReplicatedJedisPool redisClient = new RedisClientFactory("directory_cache_delete_command", configuration.getDirectoryConfiguration().getRedisConfiguration().getUrl(), configuration.getDirectoryConfiguration().getRedisConfiguration().getReplicaUrls(), configuration.getDirectoryConfiguration().getRedisConfiguration().getCircuitBreakerConfiguration()).getRedisClientPool();
DirectoryQueue directoryQueue = new DirectoryQueue(configuration.getDirectoryConfiguration().getSqsConfiguration());

View File

@ -1,50 +0,0 @@
package org.whispersystems.textsecuregcm.workers;
import net.sourceforge.argparse4j.inf.Namespace;
import org.skife.jdbi.v2.DBI;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.storage.Messages;
import java.util.concurrent.TimeUnit;
import io.dropwizard.cli.ConfiguredCommand;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.jdbi.ImmutableListContainerFactory;
import io.dropwizard.jdbi.ImmutableSetContainerFactory;
import io.dropwizard.jdbi.OptionalContainerFactory;
import io.dropwizard.jdbi.args.OptionalArgumentFactory;
import io.dropwizard.setup.Bootstrap;
public class TrimMessagesCommand extends ConfiguredCommand<WhisperServerConfiguration> {
private final Logger logger = LoggerFactory.getLogger(TrimMessagesCommand.class);
public TrimMessagesCommand() {
super("trim", "Trim Messages Database");
}
@Override
protected void run(Bootstrap<WhisperServerConfiguration> bootstrap,
Namespace namespace,
WhisperServerConfiguration config)
throws Exception
{
DataSourceFactory messageDbConfig = config.getMessageStoreConfiguration();
DBI messageDbi = new DBI(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword());
messageDbi.registerArgumentFactory(new OptionalArgumentFactory(messageDbConfig.getDriverClass()));
messageDbi.registerContainerFactory(new ImmutableListContainerFactory());
messageDbi.registerContainerFactory(new ImmutableSetContainerFactory());
messageDbi.registerContainerFactory(new OptionalContainerFactory());
Messages messages = messageDbi.onDemand(Messages.class);
long timestamp = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90);
logger.info("Trimming old messages: " + timestamp + "...");
messages.removeOld(timestamp);
Thread.sleep(3000);
System.exit(0);
}
}

View File

@ -1,7 +1,7 @@
package org.whispersystems.textsecuregcm.workers;
import net.sourceforge.argparse4j.inf.Namespace;
import org.skife.jdbi.v2.DBI;
import org.jdbi.v3.core.Jdbi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
@ -12,10 +12,6 @@ import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import io.dropwizard.cli.ConfiguredCommand;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.jdbi.ImmutableListContainerFactory;
import io.dropwizard.jdbi.ImmutableSetContainerFactory;
import io.dropwizard.jdbi.OptionalContainerFactory;
import io.dropwizard.jdbi.args.OptionalArgumentFactory;
import io.dropwizard.setup.Bootstrap;
@ -35,23 +31,15 @@ public class VacuumCommand extends ConfiguredCommand<WhisperServerConfiguration>
{
DataSourceFactory dbConfig = config.getDataSourceFactory();
DataSourceFactory messageDbConfig = config.getMessageStoreConfiguration();
DBI dbi = new DBI(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword() );
DBI messageDbi = new DBI(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword());
dbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass()));
dbi.registerContainerFactory(new ImmutableListContainerFactory());
dbi.registerContainerFactory(new ImmutableSetContainerFactory());
dbi.registerContainerFactory(new OptionalContainerFactory());
Jdbi accountDatabase = Jdbi.create(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword());
Jdbi messageDatabase = Jdbi.create(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword());
messageDbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass()));
messageDbi.registerContainerFactory(new ImmutableListContainerFactory());
messageDbi.registerContainerFactory(new ImmutableSetContainerFactory());
messageDbi.registerContainerFactory(new OptionalContainerFactory());
Accounts accounts = dbi.onDemand(Accounts.class );
Keys keys = dbi.onDemand(Keys.class );
PendingAccounts pendingAccounts = dbi.onDemand(PendingAccounts.class);
Messages messages = messageDbi.onDemand(Messages.class);
Accounts accounts = new Accounts(accountDatabase);
Keys keys = new Keys(accountDatabase);
PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase);
Messages messages = new Messages(messageDatabase);
logger.info("Vacuuming accounts...");
accounts.vacuum();

View File

@ -0,0 +1,85 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class AbusiveHostRulesTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("abusedb.xml"));
private AbusiveHostRules abusiveHostRules;
@Before
public void setup() {
this.abusiveHostRules = new AbusiveHostRules(Jdbi.create(db.getTestDatabase()));
}
@Test
public void testBlockedHost() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.1");
statement.setInt(2, 1);
statement.execute();
List<AbusiveHostRule> rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.1");
assertThat(rules.size()).isEqualTo(1);
assertThat(rules.get(0).getRegions().isEmpty()).isTrue();
assertThat(rules.get(0).getHost()).isEqualTo("192.168.1.1");
assertThat(rules.get(0).isBlocked()).isTrue();
}
@Test
public void testBlockedCidr() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 1);
statement.execute();
List<AbusiveHostRule> rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.1");
assertThat(rules.size()).isEqualTo(1);
assertThat(rules.get(0).getRegions().isEmpty()).isTrue();
assertThat(rules.get(0).getHost()).isEqualTo("192.168.1.0/24");
assertThat(rules.get(0).isBlocked()).isTrue();
}
@Test
public void testUnblocked() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked) VALUES (?::INET, ?)");
statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 1);
statement.execute();
List<AbusiveHostRule> rules = abusiveHostRules.getAbusiveHostRulesFor("172.17.1.1");
assertThat(rules.isEmpty()).isTrue();
}
@Test
public void testRestricted() throws SQLException {
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("INSERT INTO abusive_host_rules (host, blocked, regions) VALUES (?::INET, ?, ?)");
statement.setString(1, "192.168.1.0/24");
statement.setInt(2, 0);
statement.setString(3, "+1,+49");
statement.execute();
List<AbusiveHostRule> rules = abusiveHostRules.getAbusiveHostRulesFor("192.168.1.100");
assertThat(rules.size()).isEqualTo(1);
assertThat(rules.get(0).isBlocked()).isFalse();
assertThat(rules.get(0).getRegions()).isEqualTo(Arrays.asList("+1", "+49"));
}
}

View File

@ -54,7 +54,7 @@ public class AccountsManagerTest {
when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("Account5+14152222222"))).thenReturn(null);
when(accounts.get(eq("+14152222222"))).thenReturn(account);
when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> retrieved = accountsManager.get("+14152222222");
@ -82,7 +82,7 @@ public class AccountsManagerTest {
when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("Account5+14152222222"))).thenThrow(new JedisException("Connection lost!"));
when(accounts.get(eq("+14152222222"))).thenReturn(account);
when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> retrieved = accountsManager.get("+14152222222");

View File

@ -0,0 +1,244 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper;
import java.io.IOException;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class AccountsTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private Accounts accounts;
@Before
public void setupAccountsDao() {
this.accounts = new Accounts(Jdbi.create(db.getTestDatabase()));
}
@Test
public void testStore() throws SQLException, IOException {
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account);
}
@Test
public void testStoreMulti() throws SQLException, IOException {
Set<Device> devices = new HashSet<>();
devices.add(generateDevice(1));
devices.add(generateDevice(2));
Account account = generateAccount("+14151112222", devices);
accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account);
}
@Test
public void testRetrieve() {
Set<Device> devicesFirst = new HashSet<>();
devicesFirst.add(generateDevice(1));
devicesFirst.add(generateDevice(2));
Account accountFirst = generateAccount("+14151112222", devicesFirst);
Set<Device> devicesSecond = new HashSet<>();
devicesSecond.add(generateDevice(1));
devicesSecond.add(generateDevice(2));
Account accountSecond = generateAccount("+14152221111", devicesSecond);
accounts.create(accountFirst);
accounts.create(accountSecond);
Optional<Account> retrievedFirst = accounts.get("+14151112222");
Optional<Account> retrievedSecond = accounts.get("+14152221111");
assertThat(retrievedFirst.isPresent()).isTrue();
assertThat(retrievedSecond.isPresent()).isTrue();
verifyStoredState("+14151112222", retrievedFirst.get(), accountFirst);
verifyStoredState("+14152221111", retrievedSecond.get(), accountSecond);
}
@Test
public void testOverwrite() throws Exception {
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account);
device = generateDevice(1);
account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
verifyStoredState(statement, "+14151112222", account);
}
@Test
public void testUpdate() {
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
device.setName("foobar");
accounts.update(account);
Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", retrieved.get(), account);
}
@Test
public void testRetrieveFrom() {
List<Account> users = new ArrayList<>();
for (int i=1;i<=100;i++) {
Account account = generateAccount("+1" + String.format("%03d", i));
users.add(account);
accounts.create(account);
}
List<Account> retrieved = accounts.getAllFrom(10);
assertThat(retrieved.size()).isEqualTo(10);
for (int i=0;i<retrieved.size();i++) {
verifyStoredState("+1" + String.format("%03d", (i + 1)), retrieved.get(i), users.get(i));
}
for (int j=0;j<9;j++) {
retrieved = accounts.getAllFrom(retrieved.get(9).getNumber(), 10);
assertThat(retrieved.size()).isEqualTo(10);
for (int i=0;i<retrieved.size();i++) {
verifyStoredState("+1" + String.format("%03d", (10 + (j * 10) + i + 1)), retrieved.get(i), users.get(10 + (j * 10) + i));
}
}
}
@Test
public void testVacuum() {
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
accounts.vacuum();
Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", retrieved.get(), account);
}
@Test
public void testMissing() {
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device));
accounts.create(account);
Optional<Account> retrieved = accounts.get("+11111111");
assertThat(retrieved.isPresent()).isFalse();
}
private Device generateDevice(long id) {
Random random = new Random(System.currentTimeMillis());
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt());
return new Device(1, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(), null, "testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt(), random.nextBoolean());
}
private Account generateAccount(String number) {
Device device = generateDevice(1);
return generateAccount(number, Collections.singleton(device));
}
private Account generateAccount(String number, Set<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
return new Account(number, devices, unidentifiedAccessKey);
}
private void verifyStoredState(PreparedStatement statement, String number, Account expecting)
throws SQLException, IOException
{
statement.setString(1, number);
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
String data = resultSet.getString("data");
assertThat(data).isNotEmpty();
Account result = new AccountRowMapper().map(resultSet, null);
verifyStoredState(number, result, expecting);
} else {
throw new AssertionError("No data");
}
assertThat(resultSet.next()).isFalse();
}
private void verifyStoredState(String number, Account result, Account expecting) {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {
Device resultDevice = result.getDevice(expectingDevice.getId()).get();
assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId());
assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId());
assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen());
assertThat(resultDevice.getSignedPreKey().getPublicKey()).isEqualTo(expectingDevice.getSignedPreKey().getPublicKey());
assertThat(resultDevice.getSignedPreKey().getKeyId()).isEqualTo(expectingDevice.getSignedPreKey().getKeyId());
assertThat(resultDevice.getSignedPreKey().getSignature()).isEqualTo(expectingDevice.getSignedPreKey().getSignature());
assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages());
assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent());
assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName());
assertThat(resultDevice.getCreated()).isEqualTo(expectingDevice.getCreated());
}
}
}

View File

@ -3,9 +3,9 @@ package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Rule;
import org.junit.Test;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
@ -27,8 +27,8 @@ public class KeysTest {
@Test
public void testPopulateKeys() throws SQLException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
@ -63,10 +63,10 @@ public class KeysTest {
}
@Test
public void testKeyCount() throws SQLException {
public void testKeyCount() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
@ -83,8 +83,8 @@ public class KeysTest {
@Test
public void testGetForDevice() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
@ -144,8 +144,8 @@ public class KeysTest {
@Test
public void testGetForAllDevices() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
@ -220,8 +220,8 @@ public class KeysTest {
@Test
public void testGetForAllDevicesParallel() throws InterruptedException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
@ -239,7 +239,7 @@ public class KeysTest {
List<Thread> threads = new LinkedList<>();
for (int i=0;i<50;i++) {
for (int i=0;i<20;i++) {
Thread thread = new Thread(() -> {
List<KeyRecord> results = keys.get("+14152222222");
assertThat(results.size()).isEqualTo(2);
@ -252,21 +252,31 @@ public class KeysTest {
thread.join();
}
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50);
assertThat(keys.getCount("+14152222222",2)).isEqualTo(50);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(80);
assertThat(keys.getCount("+14152222222",2)).isEqualTo(80);
}
@Test
public void testEmptyKeyGet() {
DBI dbi = new DBI(db.getTestDatabase());
Keys keys = dbi.onDemand(Keys.class);
DataSource dataSource = db.getTestDatabase();
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
List<KeyRecord> records = keys.get("+14152222222");
assertThat(records.isEmpty()).isTrue();
}
@Test
public void testVacuum() {
DataSource dataSource = db.getTestDatabase();
Jdbi jdbi = Jdbi.create(dataSource);
Keys keys = new Keys(jdbi);
keys.vacuum();
}
private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException {
statement.setString(1, number);

View File

@ -0,0 +1,252 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.google.protobuf.ByteString;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.storage.Messages;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class MessagesTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("messagedb.xml"));
private Messages messages;
@Before
public void setupAccountsDao() {
this.messages = new Messages(Jdbi.create(db.getTestDatabase()));
}
@Test
public void testStore() throws SQLException {
Envelope envelope = generateEnvelope();
UUID guid = UUID.randomUUID();
messages.store(guid, envelope, "+14151112222", 1);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?");
statement.setString(1, "+14151112222");
ResultSet resultSet = statement.executeQuery();
assertThat(resultSet.next()).isTrue();
assertThat(resultSet.getString("guid")).isEqualTo(guid.toString());
assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber());
assertThat(resultSet.getString("relay")).isNullOrEmpty();
assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp());
assertThat(resultSet.getLong("server_timestamp")).isEqualTo(envelope.getServerTimestamp());
assertThat(resultSet.getString("source")).isEqualTo(envelope.getSource());
assertThat(resultSet.getLong("source_device")).isEqualTo(envelope.getSourceDevice());
assertThat(resultSet.getBytes("message")).isEqualTo(envelope.getLegacyMessage().toByteArray());
assertThat(resultSet.getBytes("content")).isEqualTo(envelope.getContent().toByteArray());
assertThat(resultSet.getString("destination")).isEqualTo("+14151112222");
assertThat(resultSet.getLong("destination_device")).isEqualTo(1);
assertThat(resultSet.next()).isFalse();
}
@Test
public void testLoad() {
List<MessageToStore> inserted = new ArrayList<>(50);
for (int i=0;i<50;i++) {
MessageToStore message = generateMessageToStore();
inserted.add(message);
messages.store(message.guid, message.envelope, "+14151112222", 1);
}
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
List<OutgoingMessageEntity> retrieved = messages.load("+14151112222", 1);
assertThat(retrieved.size()).isEqualTo(inserted.size());
for (int i=0;i<retrieved.size();i++) {
verifyExpected(retrieved.get(i), inserted.get(i).envelope, inserted.get(i).guid);
}
}
@Test
public void removeBySourceDestinationTimestamp() {
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
List<MessageToStore> unrelated = insertRandom("+14151114444", 3);
MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", 1, toRemove.envelope.getSource(), toRemove.envelope.getTimestamp());
assertThat(removed.isPresent()).isTrue();
verifyExpected(removed.get(), toRemove.envelope, toRemove.guid);
verifyInTact(inserted, "+14151112222", 1);
verifyInTact(unrelated, "+14151114444", 3);
}
@Test
public void removeByDestinationGuid() {
List<MessageToStore> unrelated = insertRandom("+14151113333", 2);
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", toRemove.guid);
assertThat(removed.isPresent()).isTrue();
verifyExpected(removed.get(), toRemove.envelope, toRemove.guid);
verifyInTact(inserted, "+14151112222", 1);
verifyInTact(unrelated, "+14151113333", 2);
}
@Test
public void removeByDestinationRowId() {
List<MessageToStore> unrelatedInserted = insertRandom("+14151111111", 1);
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
List<OutgoingMessageEntity> retrieved = messages.load("+14151112222", 1);
int toRemoveIndex = new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1);
inserted.remove(toRemoveIndex);
messages.remove("+14151112222", retrieved.get(toRemoveIndex).getId());
verifyInTact(inserted, "+14151112222", 1);
verifyInTact(unrelatedInserted, "+14151111111", 1);
}
@Test
public void testLoadEmpty() {
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
List<OutgoingMessageEntity> loaded = messages.load("+14159999999", 1);
assertThat(loaded.isEmpty()).isTrue();
}
@Test
public void testClearDestination() {
insertRandom("+14151112222", 1);
insertRandom("+14151112222", 2);
List<MessageToStore> unrelated = insertRandom("+14151111111", 1);
messages.clear("+14151112222");
assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue();
verifyInTact(unrelated, "+14151111111", 1);
}
@Test
public void testClearDestinationDevice() {
insertRandom("+14151112222", 1);
List<MessageToStore> inserted = insertRandom("+14151112222", 2);
List<MessageToStore> unrelated = insertRandom("+14151111111", 1);
messages.clear("+14151112222", 1);
assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue();
verifyInTact(inserted, "+14151112222", 2);
verifyInTact(unrelated, "+14151111111", 1);
}
@Test
public void testVacuum() {
List<MessageToStore> inserted = insertRandom("+14151112222", 2);
messages.vacuum();
verifyInTact(inserted, "+14151112222", 2);
}
private List<MessageToStore> insertRandom(String destination, int destinationDevice) {
List<MessageToStore> inserted = new ArrayList<>(50);
for (int i=0;i<50;i++) {
MessageToStore message = generateMessageToStore();
inserted.add(message);
messages.store(message.guid, message.envelope, destination, destinationDevice);
}
return inserted;
}
private void verifyInTact(List<MessageToStore> inserted, String destination, int destinationDevice) {
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
List<OutgoingMessageEntity> retrieved = messages.load(destination, destinationDevice);
assertThat(retrieved.size()).isEqualTo(inserted.size());
for (int i=0;i<retrieved.size();i++) {
verifyExpected(retrieved.get(i), inserted.get(i).envelope, inserted.get(i).guid);
}
}
private void verifyExpected(OutgoingMessageEntity retrieved, Envelope inserted, UUID guid) {
assertThat(retrieved.getSource()).isEqualTo(inserted.getSource());
assertThat(retrieved.getTimestamp()).isEqualTo(inserted.getTimestamp());
assertThat(retrieved.getRelay()).isEqualTo(inserted.getRelay());
assertThat(retrieved.getType()).isEqualTo(inserted.getType().getNumber());
assertThat(retrieved.getContent()).isEqualTo(inserted.getContent().toByteArray());
assertThat(retrieved.getMessage()).isEqualTo(inserted.getLegacyMessage().toByteArray());
assertThat(retrieved.getServerTimestamp()).isEqualTo(inserted.getServerTimestamp());
assertThat(retrieved.getGuid()).isEqualTo(guid);
assertThat(retrieved.getSourceDevice()).isEqualTo(inserted.getSourceDevice());
}
private MessageToStore generateMessageToStore() {
return new MessageToStore(UUID.randomUUID(), generateEnvelope());
}
private Envelope generateEnvelope() {
Random random = new Random();
byte[] content = new byte[256];
byte[] legacy = new byte[200];
Arrays.fill(content, (byte)random.nextInt(255));
Arrays.fill(legacy, (byte)random.nextInt(255));
return Envelope.newBuilder()
.setSourceDevice(random.nextInt(10000))
.setSource("testSource" + random.nextInt())
.setTimestamp(random.nextInt(100000))
.setServerTimestamp(random.nextInt(100000))
.setLegacyMessage(ByteString.copyFrom(legacy))
.setContent(ByteString.copyFrom(content))
.setType(Envelope.Type.CIPHERTEXT)
.setServerGuid(UUID.randomUUID().toString())
.build();
}
private static class MessageToStore {
private final UUID guid;
private final Envelope envelope;
private MessageToStore(UUID guid, Envelope envelope) {
this.guid = guid;
this.envelope = envelope;
}
}
}

View File

@ -0,0 +1,114 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class PendingAccountsTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private PendingAccounts pendingAccounts;
@Before
public void setupAccountsDao() {
this.pendingAccounts = new PendingAccounts(Jdbi.create(db.getTestDatabase()));
}
@Test
public void testStore() throws SQLException {
pendingAccounts.insert("+14151112222", "1234", 1111);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_accounts WHERE number = ?");
statement.setString(1, "+14151112222");
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
assertThat(resultSet.getString("verification_code")).isEqualTo("1234");
assertThat(resultSet.getLong("timestamp")).isEqualTo(1111);
} else {
throw new AssertionError("no results");
}
assertThat(resultSet.next()).isFalse();
}
@Test
public void testRetrieve() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222);
pendingAccounts.insert("+14151113333", "1212", 5555);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4321");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222);
Optional<StoredVerificationCode> missingCode = pendingAccounts.getCodeForNumber("+11111111111");
assertThat(missingCode.isPresent()).isFalse();
}
@Test
public void testOverwrite() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222);
pendingAccounts.insert("+14151112222", "4444", 3333);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4444");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333);
}
@Test
public void testVacuum() {
pendingAccounts.insert("+14151112222", "4321", 2222);
pendingAccounts.insert("+14151112222", "4444", 3333);
pendingAccounts.vacuum();
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4444");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333);
}
@Test
public void testRemove() {
pendingAccounts.insert("+14151112222", "4321", 2222);
pendingAccounts.insert("+14151113333", "1212", 5555);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4321");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222);
pendingAccounts.remove("+14151112222");
verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isFalse();
verificationCode = pendingAccounts.getCodeForNumber("+14151113333");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("1212");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555);
}
}

View File

@ -0,0 +1,100 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.jdbi.v3.core.Jdbi;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.PendingDevices;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Optional;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class PendingDevicesTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private PendingDevices pendingDevices;
@Before
public void setupAccountsDao() {
this.pendingDevices = new PendingDevices(Jdbi.create(db.getTestDatabase()));
}
@Test
public void testStore() throws SQLException {
pendingDevices.insert("+14151112222", "1234", 1111);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_devices WHERE number = ?");
statement.setString(1, "+14151112222");
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
assertThat(resultSet.getString("verification_code")).isEqualTo("1234");
assertThat(resultSet.getLong("timestamp")).isEqualTo(1111);
} else {
throw new AssertionError("no results");
}
assertThat(resultSet.next()).isFalse();
}
@Test
public void testRetrieve() throws Exception {
pendingDevices.insert("+14151112222", "4321", 2222);
pendingDevices.insert("+14151113333", "1212", 5555);
Optional<StoredVerificationCode> verificationCode = pendingDevices.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4321");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222);
Optional<StoredVerificationCode> missingCode = pendingDevices.getCodeForNumber("+11111111111");
assertThat(missingCode.isPresent()).isFalse();
}
@Test
public void testOverwrite() throws Exception {
pendingDevices.insert("+14151112222", "4321", 2222);
pendingDevices.insert("+14151112222", "4444", 3333);
Optional<StoredVerificationCode> verificationCode = pendingDevices.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4444");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333);
}
@Test
public void testRemove() {
pendingDevices.insert("+14151112222", "4321", 2222);
pendingDevices.insert("+14151113333", "1212", 5555);
Optional<StoredVerificationCode> verificationCode = pendingDevices.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4321");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222);
pendingDevices.remove("+14151112222");
verificationCode = pendingDevices.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isFalse();
verificationCode = pendingDevices.getCodeForNumber("+14151113333");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("1212");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555);
}
}