Eliminate last vestiges of "last resort" key stuff

This commit is contained in:
Moxie Marlinspike 2019-03-26 18:00:53 -07:00
parent 77142eb2df
commit 890b0ac301
6 changed files with 371 additions and 97 deletions

View File

@ -166,6 +166,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.opentable.components</groupId>
<artifactId>otj-pg-embedded</artifactId>
<version>0.13.1</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
@ -135,19 +135,17 @@ public class KeysController {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId);
}
Optional<List<KeyRecord>> targetKeys = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> devices = new LinkedList<>();
List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> devices = new LinkedList<>();
for (Device device : target.get().getDevices()) {
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey();
PreKey preKey = null;
if (targetKeys.isPresent()) {
for (KeyRecord keyRecord : targetKeys.get()) {
if (!keyRecord.isLastResort() && keyRecord.getDeviceId() == device.getId()) {
for (KeyRecord keyRecord : targetKeys) {
if (keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey());
}
}
}
@ -189,7 +187,7 @@ public class KeysController {
else return Optional.empty();
}
private Optional<List<KeyRecord>> getLocalKeys(Account destination, String deviceIdSelector) {
private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) {
try {
if (deviceIdSelector.equals("*")) {
return keys.get(destination.getNumber());

View File

@ -7,17 +7,13 @@ public class KeyRecord {
private long deviceId;
private long keyId;
private String publicKey;
private boolean lastResort;
public KeyRecord(long id, String number, long deviceId, long keyId,
String publicKey, boolean lastResort)
{
public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) {
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.keyId = keyId;
this.publicKey = publicKey;
this.lastResort = lastResort;
}
public long getId() {
@ -40,7 +36,4 @@ public class KeyRecord {
return publicKey;
}
public boolean isLastResort() {
return lastResort;
}
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -38,97 +38,69 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
public abstract class Keys {
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
abstract void removeKeys(@Bind("number") String number, @Bind("device_id") long deviceId);
abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlUpdate("DELETE FROM keys WHERE id = :id")
abstract void removeKey(@Bind("id") long id);
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)")
abstract void append(@KeyRecordBinder List<KeyRecord> preKeys);
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " +
"(:number, :device_id, :key_id, :public_key, :last_resort)")
abstract void append(@PreKeyBinder List<KeyRecord> preKeys);
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
@Mapper(KeyRecordMapper.class)
abstract List<KeyRecord> getInternal(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE")
@Mapper(PreKeyMapper.class)
abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC")
@Mapper(PreKeyMapper.class)
abstract List<KeyRecord> retrieveFirst(@Bind("number") String number);
@SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
@Mapper(KeyRecordMapper.class)
abstract List<KeyRecord> getInternal(@Bind("number") String number);
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
// Apparently transaction annotations don't work on the annotated query methods
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number) {
return getInternal(number);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number, long deviceId) {
return getInternal(number, deviceId);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List<PreKey> keys) {
List<KeyRecord> records = new LinkedList<>();
List<KeyRecord> records = keys.stream()
.map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey()))
.collect(Collectors.toList());
for (PreKey key : keys) {
records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false));
}
removeKeys(number, deviceId);
remove(number, deviceId);
append(records);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<List<KeyRecord>> get(String number, long deviceId) {
final KeyRecord record = retrieveFirst(number, deviceId);
if (record != null && !record.isLastResort()) {
removeKey(record.getId());
} else if (record == null) {
return Optional.empty();
}
List<KeyRecord> results = new LinkedList<>();
results.add(record);
return Optional.of(results);
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<List<KeyRecord>> get(String number) {
List<KeyRecord> preKeys = retrieveFirst(number);
if (preKeys != null) {
for (KeyRecord preKey : preKeys) {
if (!preKey.isLastResort()) {
removeKey(preKey.getId());
}
}
}
if (preKeys != null) return Optional.of(preKeys);
else return Optional.empty();
}
@SqlUpdate("VACUUM keys")
public abstract void vacuum();
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)
@BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface PreKeyBinder {
public @interface KeyRecordBinder {
public static class PreKeyBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
return new Binder<PreKeyBinder, KeyRecord>() {
return new Binder<KeyRecordBinder, KeyRecord>() {
@Override
public void bind(SQLStatement<?> sql, PreKeyBinder accountBinder, KeyRecord record)
public void bind(SQLStatement<?> sql, KeyRecordBinder keyRecordBinder, KeyRecord record)
{
sql.bind("id", record.getId());
sql.bind("number", record.getNumber());
sql.bind("device_id", record.getDeviceId());
sql.bind("key_id", record.getKeyId());
sql.bind("public_key", record.getPublicKey());
sql.bind("last_resort", record.isLastResort() ? 1 : 0);
}
};
}
@ -136,14 +108,14 @@ public abstract class Keys {
}
public static class PreKeyMapper implements ResultSetMapper<KeyRecord> {
public static class KeyRecordMapper implements ResultSetMapper<KeyRecord> {
@Override
public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
resultSet.getLong("device_id"), resultSet.getLong("key_id"),
resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1);
resultSet.getString("public_key"));
}
}

View File

@ -45,15 +45,16 @@ public class KeyControllerTest {
private static int SAMPLE_REGISTRATION_ID2 = 1002;
private static int SAMPLE_REGISTRATION_ID4 = 1555;
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false);
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false );
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false );
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false );
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1");
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3");
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5");
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6");
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11");
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22");
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33");
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" );
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey( 2222, "foobar", "sig22" );
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final Keys keys = mock(Keys.class );
private final AccountsManager accounts = mock(AccountsManager.class);
@ -120,20 +121,20 @@ public class KeyControllerTest {
List<KeyRecord> singleDevice = new LinkedList<>();
singleDevice.add(SAMPLE_KEY);
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice));
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice);
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<List<KeyRecord>>empty());
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>());
List<KeyRecord> multiDevice = new LinkedList<>();
multiDevice.add(SAMPLE_KEY);
multiDevice.add(SAMPLE_KEY2);
multiDevice.add(SAMPLE_KEY3);
multiDevice.add(SAMPLE_KEY4);
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice));
when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice);
when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid"));
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@ -146,7 +147,7 @@ public class KeyControllerTest {
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class);
assertThat(result.getCount() == 4);
assertThat(result.getCount()).isEqualTo(4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
}
@ -159,7 +160,9 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
assertThat(result.equals(SAMPLE_SIGNED_KEY));
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
}
@Test
@ -171,7 +174,7 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus() == 204);
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
@ -332,10 +335,9 @@ public class KeyControllerTest {
@Test
public void putKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final PreKey lastResortKey = new PreKey(31339, "barbar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
List<PreKey> preKeys = new LinkedList<PreKey>() {{
add(preKey);
@ -356,9 +358,9 @@ public class KeyControllerTest {
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size() == 1);
assertThat(capturedList.get(0).getKeyId() == 31337);
assertThat(capturedList.get(0).getPublicKey().equals("foobar"));
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));

View File

@ -0,0 +1,303 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule;
import org.junit.Rule;
import org.junit.Test;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.sql.DataSource;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class KeysTest {
@Rule
public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
@Test
public void testPopulateKeys() throws SQLException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> oldAnotherDeviceOnePrKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
oldAnotherDeviceOnePrKeys.add(new PreKey(i, "OldPublicKey" + i));
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
PreparedStatement statement = dataSource.getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id");
verifyStoredState(statement, "+14152222222", 1);
verifyStoredState(statement, "+14152222222", 2);
verifyStoredState(statement, "+14151111111", 1);
verifyStoredState(statement, "+14151111111", 2);
}
@Test
public void testKeyCount() throws SQLException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
}
@Test
public void testGetForDevice() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
List<KeyRecord> records = keys.get("+14152222222", 1);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
records = keys.get("+14152222222", 1);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(2);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
records = keys.get("+14152222222", 2);
assertThat(records.size()).isEqualTo(1);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1");
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
}
@Test
public void testGetForAllDevices() {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceOnePreKeys = new LinkedList<>();
List<PreKey> anotherDeviceTwoPreKeys = new LinkedList<>();
List<PreKey> anotherDeviceThreePreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
for (int i=1;i<=100;i++) {
anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
keys.store("+14151111111", 3, anotherDeviceThreePreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
List<KeyRecord> records = keys.get("+14152222222");
assertThat(records.size()).isEqualTo(2);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(1).getKeyId()).isEqualTo(1);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue();
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
records = keys.get("+14152222222");
assertThat(records.size()).isEqualTo(2);
assertThat(records.get(0).getKeyId()).isEqualTo(2);
assertThat(records.get(1).getKeyId()).isEqualTo(2);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue();
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98);
records = keys.get("+14151111111");
assertThat(records.size()).isEqualTo(3);
assertThat(records.get(0).getKeyId()).isEqualTo(1);
assertThat(records.get(1).getKeyId()).isEqualTo(1);
assertThat(records.get(2).getKeyId()).isEqualTo(1);
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device1PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue();
assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue();
assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99);
assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99);
}
@Test
public void testGetForAllDevicesParallel() throws InterruptedException {
DataSource dataSource = db.getTestDatabase();
DBI dbi = new DBI(dataSource);
Keys keys = dbi.onDemand(Keys.class);
List<PreKey> deviceOnePreKeys = new LinkedList<>();
List<PreKey> deviceTwoPreKeys = new LinkedList<>();
for (int i=1;i<=100;i++) {
deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
}
keys.store("+14152222222", 1, deviceOnePreKeys);
keys.store("+14152222222", 2, deviceTwoPreKeys);
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
List<Thread> threads = new LinkedList<>();
for (int i=0;i<50;i++) {
Thread thread = new Thread(() -> {
for (int j=0;j<10;j++) {
try {
List<KeyRecord> results = keys.get("+14152222222");
assertThat(results.size()).isEqualTo(2);
return;
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
throw new AssertionError();
});
thread.start();
threads.add(thread);
}
for (Thread thread : threads) {
thread.join();
}
assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50);
assertThat(keys.getCount("+14152222222",2)).isEqualTo(50);
}
@Test
public void testEmptyKeyGet() {
DBI dbi = new DBI(db.getTestDatabase());
Keys keys = dbi.onDemand(Keys.class);
List<KeyRecord> records = keys.get("+14152222222");
assertThat(records.isEmpty()).isTrue();
}
private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException {
statement.setString(1, number);
statement.setInt(2, deviceId);
ResultSet resultSet = statement.executeQuery();
int rowCount = 1;
while (resultSet.next()) {
long keyId = resultSet.getLong("key_id");
String publicKey = resultSet.getString("public_key");
assertThat(keyId).isEqualTo(rowCount);
assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount);
rowCount++;
}
resultSet.close();
assertThat(rowCount).isEqualTo(101);
}
}