Eliminate last vestiges of "last resort" key stuff

This commit is contained in:
Moxie Marlinspike 2019-03-26 18:00:53 -07:00
parent 77142eb2df
commit 890b0ac301
6 changed files with 371 additions and 97 deletions

View File

@ -166,6 +166,12 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.opentable.components</groupId>
<artifactId>otj-pg-embedded</artifactId>
<version>0.13.1</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2014 Open Whisper Systems * Copyright (C) 2014 Open Whisper Systems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -135,19 +135,17 @@ public class KeysController {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId); rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId);
} }
Optional<List<KeyRecord>> targetKeys = getLocalKeys(target.get(), deviceId); List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> devices = new LinkedList<>(); List<PreKeyResponseItem> devices = new LinkedList<>();
for (Device device : target.get().getDevices()) { for (Device device : target.get().getDevices()) {
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) { if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey(); SignedPreKey signedPreKey = device.getSignedPreKey();
PreKey preKey = null; PreKey preKey = null;
if (targetKeys.isPresent()) { for (KeyRecord keyRecord : targetKeys) {
for (KeyRecord keyRecord : targetKeys.get()) { if (keyRecord.getDeviceId() == device.getId()) {
if (!keyRecord.isLastResort() && keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey()); preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey());
}
} }
} }
@ -189,7 +187,7 @@ public class KeysController {
else return Optional.empty(); else return Optional.empty();
} }
private Optional<List<KeyRecord>> getLocalKeys(Account destination, String deviceIdSelector) { private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) {
try { try {
if (deviceIdSelector.equals("*")) { if (deviceIdSelector.equals("*")) {
return keys.get(destination.getNumber()); return keys.get(destination.getNumber());

View File

@ -7,17 +7,13 @@ public class KeyRecord {
private long deviceId; private long deviceId;
private long keyId; private long keyId;
private String publicKey; private String publicKey;
private boolean lastResort;
public KeyRecord(long id, String number, long deviceId, long keyId, public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) {
String publicKey, boolean lastResort)
{
this.id = id; this.id = id;
this.number = number; this.number = number;
this.deviceId = deviceId; this.deviceId = deviceId;
this.keyId = keyId; this.keyId = keyId;
this.publicKey = publicKey; this.publicKey = publicKey;
this.lastResort = lastResort;
} }
public long getId() { public long getId() {
@ -40,7 +36,4 @@ public class KeyRecord {
return publicKey; return publicKey;
} }
public boolean isLastResort() {
return lastResort;
}
} }

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2013 Open WhisperSystems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -38,97 +38,69 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.stream.Collectors;
public abstract class Keys { public abstract class Keys {
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id") @SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
abstract void removeKeys(@Bind("number") String number, @Bind("device_id") long deviceId); abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlUpdate("DELETE FROM keys WHERE id = :id") @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)")
abstract void removeKey(@Bind("id") long id); abstract void append(@KeyRecordBinder List<KeyRecord> preKeys);
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " + @SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
"(:number, :device_id, :key_id, :public_key, :last_resort)") @Mapper(KeyRecordMapper.class)
abstract void append(@PreKeyBinder List<KeyRecord> preKeys); abstract List<KeyRecord> getInternal(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE") @SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
@Mapper(PreKeyMapper.class) @Mapper(KeyRecordMapper.class)
abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId); abstract List<KeyRecord> getInternal(@Bind("number") String number);
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC")
@Mapper(PreKeyMapper.class)
abstract List<KeyRecord> retrieveFirst(@Bind("number") String number);
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") @SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId); public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
// Apparently transaction annotations don't work on the annotated query methods
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number) {
return getInternal(number);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number, long deviceId) {
return getInternal(number, deviceId);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List<PreKey> keys) { public void store(String number, long deviceId, List<PreKey> keys) {
List<KeyRecord> records = new LinkedList<>(); List<KeyRecord> records = keys.stream()
.map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey()))
.collect(Collectors.toList());
for (PreKey key : keys) { remove(number, deviceId);
records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false));
}
removeKeys(number, deviceId);
append(records); append(records);
} }
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<List<KeyRecord>> get(String number, long deviceId) {
final KeyRecord record = retrieveFirst(number, deviceId);
if (record != null && !record.isLastResort()) {
removeKey(record.getId());
} else if (record == null) {
return Optional.empty();
}
List<KeyRecord> results = new LinkedList<>();
results.add(record);
return Optional.of(results);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<List<KeyRecord>> get(String number) {
List<KeyRecord> preKeys = retrieveFirst(number);
if (preKeys != null) {
for (KeyRecord preKey : preKeys) {
if (!preKey.isLastResort()) {
removeKey(preKey.getId());
}
}
}
if (preKeys != null) return Optional.of(preKeys);
else return Optional.empty();
}
@SqlUpdate("VACUUM keys") @SqlUpdate("VACUUM keys")
public abstract void vacuum(); public abstract void vacuum();
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class) @BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER}) @Target({ElementType.PARAMETER})
public @interface PreKeyBinder { public @interface KeyRecordBinder {
public static class PreKeyBinderFactory implements BinderFactory { public static class PreKeyBinderFactory implements BinderFactory {
@Override @Override
public Binder build(Annotation annotation) { public Binder build(Annotation annotation) {
return new Binder<PreKeyBinder, KeyRecord>() { return new Binder<KeyRecordBinder, KeyRecord>() {
@Override @Override
public void bind(SQLStatement<?> sql, PreKeyBinder accountBinder, KeyRecord record) public void bind(SQLStatement<?> sql, KeyRecordBinder keyRecordBinder, KeyRecord record)
{ {
sql.bind("id", record.getId()); sql.bind("id", record.getId());
sql.bind("number", record.getNumber()); sql.bind("number", record.getNumber());
sql.bind("device_id", record.getDeviceId()); sql.bind("device_id", record.getDeviceId());
sql.bind("key_id", record.getKeyId()); sql.bind("key_id", record.getKeyId());
sql.bind("public_key", record.getPublicKey()); sql.bind("public_key", record.getPublicKey());
sql.bind("last_resort", record.isLastResort() ? 1 : 0);
} }
}; };
} }
@ -136,14 +108,14 @@ public abstract class Keys {
} }
public static class PreKeyMapper implements ResultSetMapper<KeyRecord> { public static class KeyRecordMapper implements ResultSetMapper<KeyRecord> {
@Override @Override
public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext) public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException throws SQLException
{ {
return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"), return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
resultSet.getLong("device_id"), resultSet.getLong("key_id"), resultSet.getLong("device_id"), resultSet.getLong("key_id"),
resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1); resultSet.getString("public_key"));
} }
} }

View File

@ -45,15 +45,16 @@ public class KeyControllerTest {
private static int SAMPLE_REGISTRATION_ID2 = 1002; private static int SAMPLE_REGISTRATION_ID2 = 1002;
private static int SAMPLE_REGISTRATION_ID4 = 1555; private static int SAMPLE_REGISTRATION_ID4 = 1555;
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false); private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1");
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false ); private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3");
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false ); private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5");
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false ); private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6");
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11"); private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" );
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22"); private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey( 2222, "foobar", "sig22" );
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33"); private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final Keys keys = mock(Keys.class ); private final Keys keys = mock(Keys.class );
private final AccountsManager accounts = mock(AccountsManager.class); private final AccountsManager accounts = mock(AccountsManager.class);
@ -120,20 +121,20 @@ public class KeyControllerTest {
List<KeyRecord> singleDevice = new LinkedList<>(); List<KeyRecord> singleDevice = new LinkedList<>();
singleDevice.add(SAMPLE_KEY); singleDevice.add(SAMPLE_KEY);
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice)); when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice);
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<List<KeyRecord>>empty()); when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>());
List<KeyRecord> multiDevice = new LinkedList<>(); List<KeyRecord> multiDevice = new LinkedList<>();
multiDevice.add(SAMPLE_KEY); multiDevice.add(SAMPLE_KEY);
multiDevice.add(SAMPLE_KEY2); multiDevice.add(SAMPLE_KEY2);
multiDevice.add(SAMPLE_KEY3); multiDevice.add(SAMPLE_KEY3);
multiDevice.add(SAMPLE_KEY4); multiDevice.add(SAMPLE_KEY4);
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice)); when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice);
when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid")); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
} }
@ -146,7 +147,7 @@ public class KeyControllerTest {
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class); .get(PreKeyCount.class);
assertThat(result.getCount() == 4); assertThat(result.getCount()).isEqualTo(4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
} }
@ -159,7 +160,9 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class); .get(SignedPreKey.class);
assertThat(result.equals(SAMPLE_SIGNED_KEY)); assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
} }
@Test @Test
@ -171,7 +174,7 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus() == 204); assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
@ -332,10 +335,9 @@ public class KeyControllerTest {
@Test @Test
public void putKeysTestV2() throws Exception { public void putKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar"); final PreKey preKey = new PreKey(31337, "foobar");
final PreKey lastResortKey = new PreKey(31339, "barbar"); final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig"); final String identityKey = "barbar";
final String identityKey = "barbar";
List<PreKey> preKeys = new LinkedList<PreKey>() {{ List<PreKey> preKeys = new LinkedList<PreKey>() {{
add(preKey); add(preKey);
@ -356,9 +358,9 @@ public class KeyControllerTest {
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture()); verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue(); List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size() == 1); assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId() == 31337); assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey().equals("foobar")); assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar")); verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));

View File

@ -0,0 +1,303 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.junit.Rule;
import org.junit.Test;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.sql.DataSource;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class KeysTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
@Test
public void testPopulateKeys() throws SQLException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> oldAnotherDeviceOnePrKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
oldAnotherDeviceOnePrKeys.add(new PreKey(i, "OldPublicKey" + i));
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
PreparedStatement statement = dataSource.getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id");
verifyStoredState(statement, "+14152222222", 1);
verifyStoredState(statement, "+14152222222", 2);
verifyStoredState(statement, "+14151111111", 1);
verifyStoredState(statement, "+14151111111", 2);
}
@Test
public void testKeyCount() throws SQLException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
}
@Test
public void testGetForDevice() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
List<KeyRecord> records = keys.get("+14152222222", 1);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
records = keys.get("+14152222222", 1);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(2);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
records = keys.get("+14152222222", 2);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
}
@Test
public void testGetForAllDevices() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceThreePreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
keys.store("+14151111111", 3, anotherDeviceThreePreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
List<KeyRecord> records = keys.get("+14152222222");
assertThat(records.size()).isEqualTo(2);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(1).getKeyId()).isEqualTo(1);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue();
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
records = keys.get("+14152222222");
assertThat(records.size()).isEqualTo(2);
assertThat(records.get(0).getKeyId()).isEqualTo(2);
assertThat(records.get(1).getKeyId()).isEqualTo(2);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue();
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98);
records = keys.get("+14151111111");
assertThat(records.size()).isEqualTo(3);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(1).getKeyId()).isEqualTo(1);
assertThat(records.get(2).getKeyId()).isEqualTo(1);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device1PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue();
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99);
}
@Test
public void testGetForAllDevicesParallel() throws InterruptedException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
List<Thread> threads = new LinkedList<>();
for (int i=0;i<50;i++) {
Thread thread = new Thread(() -> {
for (int j=0;j<10;j++) {
try {
List<KeyRecord> results = keys.get("+14152222222");
assertThat(results.size()).isEqualTo(2);
return;
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
throw new AssertionError();
});
thread.start();
threads.add(thread);
}
for (Thread thread : threads) {
thread.join();
}
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50);
assertThat(keys.getCount("+14152222222",2)).isEqualTo(50);
}
@Test
public void testEmptyKeyGet() {
DBI dbi = new DBI(db.getTestDatabase());
Keys keys = dbi.onDemand(Keys.class);
List<KeyRecord> records = keys.get("+14152222222");
assertThat(records.isEmpty()).isTrue();
}
private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException {
statement.setString(1, number);
statement.setInt(2, deviceId);
ResultSet resultSet = statement.executeQuery();
int rowCount = 1;
while (resultSet.next()) {
long keyId = resultSet.getLong("key_id");
String publicKey = resultSet.getString("public_key");
assertThat(keyId).isEqualTo(rowCount);
assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount);
rowCount++;
}
resultSet.close();
assertThat(rowCount).isEqualTo(101);
}
}