Eliminate last vestiges of "last resort" key stuff
This commit is contained in:
parent
77142eb2df
commit
890b0ac301
6
pom.xml
6
pom.xml
|
@ -166,6 +166,12 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.opentable.components</groupId>
|
||||
<artifactId>otj-pg-embedded</artifactId>
|
||||
<version>0.13.1</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/**
|
||||
/*
|
||||
* Copyright (C) 2014 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
|
@ -135,19 +135,17 @@ public class KeysController {
|
|||
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId);
|
||||
}
|
||||
|
||||
Optional<List<KeyRecord>> targetKeys = getLocalKeys(target.get(), deviceId);
|
||||
List<PreKeyResponseItem> devices = new LinkedList<>();
|
||||
List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId);
|
||||
List<PreKeyResponseItem> devices = new LinkedList<>();
|
||||
|
||||
for (Device device : target.get().getDevices()) {
|
||||
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
|
||||
SignedPreKey signedPreKey = device.getSignedPreKey();
|
||||
PreKey preKey = null;
|
||||
|
||||
if (targetKeys.isPresent()) {
|
||||
for (KeyRecord keyRecord : targetKeys.get()) {
|
||||
if (!keyRecord.isLastResort() && keyRecord.getDeviceId() == device.getId()) {
|
||||
for (KeyRecord keyRecord : targetKeys) {
|
||||
if (keyRecord.getDeviceId() == device.getId()) {
|
||||
preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,7 +187,7 @@ public class KeysController {
|
|||
else return Optional.empty();
|
||||
}
|
||||
|
||||
private Optional<List<KeyRecord>> getLocalKeys(Account destination, String deviceIdSelector) {
|
||||
private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) {
|
||||
try {
|
||||
if (deviceIdSelector.equals("*")) {
|
||||
return keys.get(destination.getNumber());
|
||||
|
|
|
@ -7,17 +7,13 @@ public class KeyRecord {
|
|||
private long deviceId;
|
||||
private long keyId;
|
||||
private String publicKey;
|
||||
private boolean lastResort;
|
||||
|
||||
public KeyRecord(long id, String number, long deviceId, long keyId,
|
||||
String publicKey, boolean lastResort)
|
||||
{
|
||||
public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) {
|
||||
this.id = id;
|
||||
this.number = number;
|
||||
this.deviceId = deviceId;
|
||||
this.keyId = keyId;
|
||||
this.publicKey = publicKey;
|
||||
this.lastResort = lastResort;
|
||||
}
|
||||
|
||||
public long getId() {
|
||||
|
@ -40,7 +36,4 @@ public class KeyRecord {
|
|||
return publicKey;
|
||||
}
|
||||
|
||||
public boolean isLastResort() {
|
||||
return lastResort;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/**
|
||||
/*
|
||||
* Copyright (C) 2013 Open WhisperSystems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
|
@ -38,97 +38,69 @@ import java.lang.annotation.RetentionPolicy;
|
|||
import java.lang.annotation.Target;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public abstract class Keys {
|
||||
|
||||
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
|
||||
abstract void removeKeys(@Bind("number") String number, @Bind("device_id") long deviceId);
|
||||
abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
|
||||
|
||||
@SqlUpdate("DELETE FROM keys WHERE id = :id")
|
||||
abstract void removeKey(@Bind("id") long id);
|
||||
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)")
|
||||
abstract void append(@KeyRecordBinder List<KeyRecord> preKeys);
|
||||
|
||||
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " +
|
||||
"(:number, :device_id, :key_id, :public_key, :last_resort)")
|
||||
abstract void append(@PreKeyBinder List<KeyRecord> preKeys);
|
||||
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
|
||||
@Mapper(KeyRecordMapper.class)
|
||||
abstract List<KeyRecord> getInternal(@Bind("number") String number, @Bind("device_id") long deviceId);
|
||||
|
||||
@SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE")
|
||||
@Mapper(PreKeyMapper.class)
|
||||
abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
|
||||
|
||||
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC")
|
||||
@Mapper(PreKeyMapper.class)
|
||||
abstract List<KeyRecord> retrieveFirst(@Bind("number") String number);
|
||||
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
|
||||
@Mapper(KeyRecordMapper.class)
|
||||
abstract List<KeyRecord> getInternal(@Bind("number") String number);
|
||||
|
||||
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
|
||||
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
|
||||
|
||||
// Apparently transaction annotations don't work on the annotated query methods
|
||||
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
|
||||
public List<KeyRecord> get(String number) {
|
||||
return getInternal(number);
|
||||
}
|
||||
|
||||
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
|
||||
public List<KeyRecord> get(String number, long deviceId) {
|
||||
return getInternal(number, deviceId);
|
||||
}
|
||||
|
||||
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
|
||||
public void store(String number, long deviceId, List<PreKey> keys) {
|
||||
List<KeyRecord> records = new LinkedList<>();
|
||||
List<KeyRecord> records = keys.stream()
|
||||
.map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
for (PreKey key : keys) {
|
||||
records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false));
|
||||
}
|
||||
|
||||
removeKeys(number, deviceId);
|
||||
remove(number, deviceId);
|
||||
append(records);
|
||||
}
|
||||
|
||||
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
|
||||
public Optional<List<KeyRecord>> get(String number, long deviceId) {
|
||||
final KeyRecord record = retrieveFirst(number, deviceId);
|
||||
|
||||
if (record != null && !record.isLastResort()) {
|
||||
removeKey(record.getId());
|
||||
} else if (record == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
List<KeyRecord> results = new LinkedList<>();
|
||||
results.add(record);
|
||||
|
||||
return Optional.of(results);
|
||||
}
|
||||
|
||||
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
|
||||
public Optional<List<KeyRecord>> get(String number) {
|
||||
List<KeyRecord> preKeys = retrieveFirst(number);
|
||||
|
||||
if (preKeys != null) {
|
||||
for (KeyRecord preKey : preKeys) {
|
||||
if (!preKey.isLastResort()) {
|
||||
removeKey(preKey.getId());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (preKeys != null) return Optional.of(preKeys);
|
||||
else return Optional.empty();
|
||||
}
|
||||
|
||||
@SqlUpdate("VACUUM keys")
|
||||
public abstract void vacuum();
|
||||
|
||||
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)
|
||||
@BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target({ElementType.PARAMETER})
|
||||
public @interface PreKeyBinder {
|
||||
public @interface KeyRecordBinder {
|
||||
public static class PreKeyBinderFactory implements BinderFactory {
|
||||
@Override
|
||||
public Binder build(Annotation annotation) {
|
||||
return new Binder<PreKeyBinder, KeyRecord>() {
|
||||
return new Binder<KeyRecordBinder, KeyRecord>() {
|
||||
@Override
|
||||
public void bind(SQLStatement<?> sql, PreKeyBinder accountBinder, KeyRecord record)
|
||||
public void bind(SQLStatement<?> sql, KeyRecordBinder keyRecordBinder, KeyRecord record)
|
||||
{
|
||||
sql.bind("id", record.getId());
|
||||
sql.bind("number", record.getNumber());
|
||||
sql.bind("device_id", record.getDeviceId());
|
||||
sql.bind("key_id", record.getKeyId());
|
||||
sql.bind("public_key", record.getPublicKey());
|
||||
sql.bind("last_resort", record.isLastResort() ? 1 : 0);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -136,14 +108,14 @@ public abstract class Keys {
|
|||
}
|
||||
|
||||
|
||||
public static class PreKeyMapper implements ResultSetMapper<KeyRecord> {
|
||||
public static class KeyRecordMapper implements ResultSetMapper<KeyRecord> {
|
||||
@Override
|
||||
public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
|
||||
throws SQLException
|
||||
{
|
||||
return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
|
||||
resultSet.getLong("device_id"), resultSet.getLong("key_id"),
|
||||
resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1);
|
||||
resultSet.getString("public_key"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,15 +45,16 @@ public class KeyControllerTest {
|
|||
private static int SAMPLE_REGISTRATION_ID2 = 1002;
|
||||
private static int SAMPLE_REGISTRATION_ID4 = 1555;
|
||||
|
||||
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false);
|
||||
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false );
|
||||
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false );
|
||||
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false );
|
||||
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1");
|
||||
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3");
|
||||
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5");
|
||||
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6");
|
||||
|
||||
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11");
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22");
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33");
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" );
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey( 2222, "foobar", "sig22" );
|
||||
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
|
||||
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
|
||||
|
||||
private final Keys keys = mock(Keys.class );
|
||||
private final AccountsManager accounts = mock(AccountsManager.class);
|
||||
|
@ -120,20 +121,20 @@ public class KeyControllerTest {
|
|||
|
||||
List<KeyRecord> singleDevice = new LinkedList<>();
|
||||
singleDevice.add(SAMPLE_KEY);
|
||||
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice));
|
||||
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice);
|
||||
|
||||
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<List<KeyRecord>>empty());
|
||||
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>());
|
||||
|
||||
List<KeyRecord> multiDevice = new LinkedList<>();
|
||||
multiDevice.add(SAMPLE_KEY);
|
||||
multiDevice.add(SAMPLE_KEY2);
|
||||
multiDevice.add(SAMPLE_KEY3);
|
||||
multiDevice.add(SAMPLE_KEY4);
|
||||
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice));
|
||||
when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice);
|
||||
|
||||
when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5);
|
||||
|
||||
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid"));
|
||||
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
|
||||
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
|
||||
}
|
||||
|
||||
|
@ -146,7 +147,7 @@ public class KeyControllerTest {
|
|||
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
|
||||
.get(PreKeyCount.class);
|
||||
|
||||
assertThat(result.getCount() == 4);
|
||||
assertThat(result.getCount()).isEqualTo(4);
|
||||
|
||||
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
|
||||
}
|
||||
|
@ -159,7 +160,9 @@ public class KeyControllerTest {
|
|||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
|
||||
.get(SignedPreKey.class);
|
||||
|
||||
assertThat(result.equals(SAMPLE_SIGNED_KEY));
|
||||
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
|
||||
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
|
||||
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -171,7 +174,7 @@ public class KeyControllerTest {
|
|||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
|
||||
|
||||
assertThat(response.getStatus() == 204);
|
||||
assertThat(response.getStatus()).isEqualTo(204);
|
||||
|
||||
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
|
||||
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
|
||||
|
@ -332,10 +335,9 @@ public class KeyControllerTest {
|
|||
|
||||
@Test
|
||||
public void putKeysTestV2() throws Exception {
|
||||
final PreKey preKey = new PreKey(31337, "foobar");
|
||||
final PreKey lastResortKey = new PreKey(31339, "barbar");
|
||||
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
|
||||
final String identityKey = "barbar";
|
||||
final PreKey preKey = new PreKey(31337, "foobar");
|
||||
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
|
||||
final String identityKey = "barbar";
|
||||
|
||||
List<PreKey> preKeys = new LinkedList<PreKey>() {{
|
||||
add(preKey);
|
||||
|
@ -356,9 +358,9 @@ public class KeyControllerTest {
|
|||
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture());
|
||||
|
||||
List<PreKey> capturedList = listCaptor.getValue();
|
||||
assertThat(capturedList.size() == 1);
|
||||
assertThat(capturedList.get(0).getKeyId() == 31337);
|
||||
assertThat(capturedList.get(0).getPublicKey().equals("foobar"));
|
||||
assertThat(capturedList.size()).isEqualTo(1);
|
||||
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
|
||||
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
|
||||
|
||||
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
|
||||
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
package org.whispersystems.textsecuregcm.tests.storage;
|
||||
|
||||
import com.opentable.db.postgres.embedded.LiquibasePreparer;
|
||||
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
|
||||
import com.opentable.db.postgres.junit.PreparedDbRule;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.skife.jdbi.v2.DBI;
|
||||
import org.whispersystems.textsecuregcm.entities.PreKey;
|
||||
import org.whispersystems.textsecuregcm.storage.KeyRecord;
|
||||
import org.whispersystems.textsecuregcm.storage.Keys;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
|
||||
public class KeysTest {
|
||||
|
||||
@Rule
|
||||
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
|
||||
|
||||
@Test
|
||||
public void testPopulateKeys() throws SQLException {
|
||||
DataSource dataSource = db.getTestDatabase();
|
||||
DBI dbi = new DBI(dataSource);
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<PreKey> deviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
List<PreKey> oldAnotherDeviceOnePrKeys = new LinkedList<>();
|
||||
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
|
||||
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
oldAnotherDeviceOnePrKeys.add(new PreKey(i, "OldPublicKey" + i));
|
||||
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
|
||||
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
keys.store("+14152222222", 1, deviceOnePreKeys);
|
||||
keys.store("+14152222222", 2, deviceTwoPreKeys);
|
||||
|
||||
keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys);
|
||||
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
|
||||
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
|
||||
|
||||
PreparedStatement statement = dataSource.getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id");
|
||||
verifyStoredState(statement, "+14152222222", 1);
|
||||
verifyStoredState(statement, "+14152222222", 2);
|
||||
verifyStoredState(statement, "+14151111111", 1);
|
||||
verifyStoredState(statement, "+14151111111", 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKeyCount() throws SQLException {
|
||||
DataSource dataSource = db.getTestDatabase();
|
||||
DBI dbi = new DBI(dataSource);
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<PreKey> deviceOnePreKeys = new LinkedList<>();
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
|
||||
}
|
||||
|
||||
|
||||
keys.store("+14152222222", 1, deviceOnePreKeys);
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetForDevice() {
|
||||
DataSource dataSource = db.getTestDatabase();
|
||||
DBI dbi = new DBI(dataSource);
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<PreKey> deviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
|
||||
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
|
||||
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
keys.store("+14152222222", 1, deviceOnePreKeys);
|
||||
keys.store("+14152222222", 2, deviceTwoPreKeys);
|
||||
|
||||
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
|
||||
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
|
||||
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
|
||||
List<KeyRecord> records = keys.get("+14152222222", 1);
|
||||
|
||||
assertThat(records.size()).isEqualTo(1);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(1);
|
||||
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1");
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
|
||||
|
||||
records = keys.get("+14152222222", 1);
|
||||
|
||||
assertThat(records.size()).isEqualTo(1);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(2);
|
||||
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2");
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
|
||||
|
||||
records = keys.get("+14152222222", 2);
|
||||
|
||||
assertThat(records.size()).isEqualTo(1);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(1);
|
||||
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1");
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
|
||||
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetForAllDevices() {
|
||||
DataSource dataSource = db.getTestDatabase();
|
||||
DBI dbi = new DBI(dataSource);
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<PreKey> deviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
|
||||
List<PreKey> anotherDeviceThreePreKeys = new LinkedList<>();
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
|
||||
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
|
||||
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
|
||||
anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i));
|
||||
}
|
||||
|
||||
keys.store("+14152222222", 1, deviceOnePreKeys);
|
||||
keys.store("+14152222222", 2, deviceTwoPreKeys);
|
||||
|
||||
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
|
||||
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
|
||||
keys.store("+14151111111", 3, anotherDeviceThreePreKeys);
|
||||
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
|
||||
|
||||
List<KeyRecord> records = keys.get("+14152222222");
|
||||
|
||||
assertThat(records.size()).isEqualTo(2);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(1);
|
||||
assertThat(records.get(1).getKeyId()).isEqualTo(1);
|
||||
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue();
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue();
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
|
||||
|
||||
records = keys.get("+14152222222");
|
||||
|
||||
assertThat(records.size()).isEqualTo(2);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(2);
|
||||
assertThat(records.get(1).getKeyId()).isEqualTo(2);
|
||||
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue();
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue();
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98);
|
||||
|
||||
|
||||
records = keys.get("+14151111111");
|
||||
|
||||
assertThat(records.size()).isEqualTo(3);
|
||||
assertThat(records.get(0).getKeyId()).isEqualTo(1);
|
||||
assertThat(records.get(1).getKeyId()).isEqualTo(1);
|
||||
assertThat(records.get(2).getKeyId()).isEqualTo(1);
|
||||
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device1PublicKey1"))).isTrue();
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue();
|
||||
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue();
|
||||
|
||||
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99);
|
||||
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99);
|
||||
assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetForAllDevicesParallel() throws InterruptedException {
|
||||
DataSource dataSource = db.getTestDatabase();
|
||||
DBI dbi = new DBI(dataSource);
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<PreKey> deviceOnePreKeys = new LinkedList<>();
|
||||
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
|
||||
|
||||
for (int i=1;i<=100;i++) {
|
||||
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
|
||||
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
|
||||
}
|
||||
|
||||
keys.store("+14152222222", 1, deviceOnePreKeys);
|
||||
keys.store("+14152222222", 2, deviceTwoPreKeys);
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
|
||||
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
|
||||
|
||||
List<Thread> threads = new LinkedList<>();
|
||||
|
||||
for (int i=0;i<50;i++) {
|
||||
Thread thread = new Thread(() -> {
|
||||
for (int j=0;j<10;j++) {
|
||||
try {
|
||||
List<KeyRecord> results = keys.get("+14152222222");
|
||||
assertThat(results.size()).isEqualTo(2);
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
System.err.println(e.getMessage());
|
||||
}
|
||||
}
|
||||
throw new AssertionError();
|
||||
});
|
||||
thread.start();
|
||||
threads.add(thread);
|
||||
}
|
||||
|
||||
for (Thread thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50);
|
||||
assertThat(keys.getCount("+14152222222",2)).isEqualTo(50);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEmptyKeyGet() {
|
||||
DBI dbi = new DBI(db.getTestDatabase());
|
||||
Keys keys = dbi.onDemand(Keys.class);
|
||||
|
||||
List<KeyRecord> records = keys.get("+14152222222");
|
||||
|
||||
assertThat(records.isEmpty()).isTrue();
|
||||
}
|
||||
|
||||
|
||||
private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException {
|
||||
statement.setString(1, number);
|
||||
statement.setInt(2, deviceId);
|
||||
|
||||
ResultSet resultSet = statement.executeQuery();
|
||||
int rowCount = 1;
|
||||
|
||||
while (resultSet.next()) {
|
||||
long keyId = resultSet.getLong("key_id");
|
||||
String publicKey = resultSet.getString("public_key");
|
||||
|
||||
|
||||
assertThat(keyId).isEqualTo(rowCount);
|
||||
assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount);
|
||||
|
||||
rowCount++;
|
||||
}
|
||||
|
||||
resultSet.close();
|
||||
|
||||
assertThat(rowCount).isEqualTo(101);
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue