diff --git a/pom.xml b/pom.xml
index 2cdd2f9fd..071e3c14d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -166,6 +166,12 @@
test
+
+ com.opentable.components
+ otj-pg-embedded
+ 0.13.1
+ test
+
diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java
index 421cb7643..fb7e238d4 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java
@@ -1,4 +1,4 @@
-/**
+/*
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
@@ -135,19 +135,17 @@ public class KeysController {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId);
}
- Optional> targetKeys = getLocalKeys(target.get(), deviceId);
- List devices = new LinkedList<>();
+ List targetKeys = getLocalKeys(target.get(), deviceId);
+ List devices = new LinkedList<>();
for (Device device : target.get().getDevices()) {
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey();
PreKey preKey = null;
- if (targetKeys.isPresent()) {
- for (KeyRecord keyRecord : targetKeys.get()) {
- if (!keyRecord.isLastResort() && keyRecord.getDeviceId() == device.getId()) {
+ for (KeyRecord keyRecord : targetKeys) {
+ if (keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey());
- }
}
}
@@ -189,7 +187,7 @@ public class KeysController {
else return Optional.empty();
}
- private Optional> getLocalKeys(Account destination, String deviceIdSelector) {
+ private List getLocalKeys(Account destination, String deviceIdSelector) {
try {
if (deviceIdSelector.equals("*")) {
return keys.get(destination.getNumber());
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java b/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java
index 6f5d0bc47..7b1222158 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java
@@ -7,17 +7,13 @@ public class KeyRecord {
private long deviceId;
private long keyId;
private String publicKey;
- private boolean lastResort;
- public KeyRecord(long id, String number, long deviceId, long keyId,
- String publicKey, boolean lastResort)
- {
+ public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) {
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.keyId = keyId;
this.publicKey = publicKey;
- this.lastResort = lastResort;
}
public long getId() {
@@ -40,7 +36,4 @@ public class KeyRecord {
return publicKey;
}
- public boolean isLastResort() {
- return lastResort;
- }
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java
index 2082cf716..07767c251 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java
@@ -1,4 +1,4 @@
-/**
+/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@@ -38,97 +38,69 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
-import java.util.LinkedList;
import java.util.List;
-import java.util.Optional;
+import java.util.stream.Collectors;
public abstract class Keys {
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
- abstract void removeKeys(@Bind("number") String number, @Bind("device_id") long deviceId);
+ abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
- @SqlUpdate("DELETE FROM keys WHERE id = :id")
- abstract void removeKey(@Bind("id") long id);
+ @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)")
+ abstract void append(@KeyRecordBinder List preKeys);
- @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " +
- "(:number, :device_id, :key_id, :public_key, :last_resort)")
- abstract void append(@PreKeyBinder List preKeys);
+ @SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *")
+ @Mapper(KeyRecordMapper.class)
+ abstract List getInternal(@Bind("number") String number, @Bind("device_id") long deviceId);
- @SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE")
- @Mapper(PreKeyMapper.class)
- abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
-
- @SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC")
- @Mapper(PreKeyMapper.class)
- abstract List retrieveFirst(@Bind("number") String number);
+ @SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *")
+ @Mapper(KeyRecordMapper.class)
+ abstract List getInternal(@Bind("number") String number);
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
+ // Apparently transaction annotations don't work on the annotated query methods
+ @Transaction(TransactionIsolationLevel.SERIALIZABLE)
+ public List get(String number) {
+ return getInternal(number);
+ }
+
+ @Transaction(TransactionIsolationLevel.SERIALIZABLE)
+ public List get(String number, long deviceId) {
+ return getInternal(number, deviceId);
+ }
+
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List keys) {
- List records = new LinkedList<>();
+ List records = keys.stream()
+ .map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey()))
+ .collect(Collectors.toList());
- for (PreKey key : keys) {
- records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false));
- }
-
- removeKeys(number, deviceId);
+ remove(number, deviceId);
append(records);
}
- @Transaction(TransactionIsolationLevel.SERIALIZABLE)
- public Optional> get(String number, long deviceId) {
- final KeyRecord record = retrieveFirst(number, deviceId);
-
- if (record != null && !record.isLastResort()) {
- removeKey(record.getId());
- } else if (record == null) {
- return Optional.empty();
- }
-
- List results = new LinkedList<>();
- results.add(record);
-
- return Optional.of(results);
- }
-
- @Transaction(TransactionIsolationLevel.SERIALIZABLE)
- public Optional> get(String number) {
- List preKeys = retrieveFirst(number);
-
- if (preKeys != null) {
- for (KeyRecord preKey : preKeys) {
- if (!preKey.isLastResort()) {
- removeKey(preKey.getId());
- }
- }
- }
-
- if (preKeys != null) return Optional.of(preKeys);
- else return Optional.empty();
- }
@SqlUpdate("VACUUM keys")
public abstract void vacuum();
- @BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)
+ @BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
- public @interface PreKeyBinder {
+ public @interface KeyRecordBinder {
public static class PreKeyBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
- return new Binder() {
+ return new Binder() {
@Override
- public void bind(SQLStatement> sql, PreKeyBinder accountBinder, KeyRecord record)
+ public void bind(SQLStatement> sql, KeyRecordBinder keyRecordBinder, KeyRecord record)
{
sql.bind("id", record.getId());
sql.bind("number", record.getNumber());
sql.bind("device_id", record.getDeviceId());
sql.bind("key_id", record.getKeyId());
sql.bind("public_key", record.getPublicKey());
- sql.bind("last_resort", record.isLastResort() ? 1 : 0);
}
};
}
@@ -136,14 +108,14 @@ public abstract class Keys {
}
- public static class PreKeyMapper implements ResultSetMapper {
+ public static class KeyRecordMapper implements ResultSetMapper {
@Override
public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
resultSet.getLong("device_id"), resultSet.getLong("key_id"),
- resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1);
+ resultSet.getString("public_key"));
}
}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java
index 7a980c008..e4a07c175 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java
@@ -45,15 +45,16 @@ public class KeyControllerTest {
private static int SAMPLE_REGISTRATION_ID2 = 1002;
private static int SAMPLE_REGISTRATION_ID4 = 1555;
- private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false);
- private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false );
- private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false );
- private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false );
+ private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1");
+ private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3");
+ private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5");
+ private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6");
- private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11");
- private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22");
- private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33");
+ private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" );
+ private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey( 2222, "foobar", "sig22" );
+ private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
+ private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final Keys keys = mock(Keys.class );
private final AccountsManager accounts = mock(AccountsManager.class);
@@ -120,20 +121,20 @@ public class KeyControllerTest {
List singleDevice = new LinkedList<>();
singleDevice.add(SAMPLE_KEY);
- when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice));
+ when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice);
- when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.>empty());
+ when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>());
List multiDevice = new LinkedList<>();
multiDevice.add(SAMPLE_KEY);
multiDevice.add(SAMPLE_KEY2);
multiDevice.add(SAMPLE_KEY3);
multiDevice.add(SAMPLE_KEY4);
- when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice));
+ when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice);
when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5);
- when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid"));
+ when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@@ -146,7 +147,7 @@ public class KeyControllerTest {
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class);
- assertThat(result.getCount() == 4);
+ assertThat(result.getCount()).isEqualTo(4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
}
@@ -159,7 +160,9 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
- assertThat(result.equals(SAMPLE_SIGNED_KEY));
+ assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
+ assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
+ assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
}
@Test
@@ -171,7 +174,7 @@ public class KeyControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
- assertThat(response.getStatus() == 204);
+ assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
@@ -332,10 +335,9 @@ public class KeyControllerTest {
@Test
public void putKeysTestV2() throws Exception {
- final PreKey preKey = new PreKey(31337, "foobar");
- final PreKey lastResortKey = new PreKey(31339, "barbar");
- final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
- final String identityKey = "barbar";
+ final PreKey preKey = new PreKey(31337, "foobar");
+ final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
+ final String identityKey = "barbar";
List preKeys = new LinkedList() {{
add(preKey);
@@ -356,9 +358,9 @@ public class KeyControllerTest {
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture());
List capturedList = listCaptor.getValue();
- assertThat(capturedList.size() == 1);
- assertThat(capturedList.get(0).getKeyId() == 31337);
- assertThat(capturedList.get(0).getPublicKey().equals("foobar"));
+ assertThat(capturedList.size()).isEqualTo(1);
+ assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
+ assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java
new file mode 100644
index 000000000..589c79cec
--- /dev/null
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java
@@ -0,0 +1,303 @@
+package org.whispersystems.textsecuregcm.tests.storage;
+
+import com.opentable.db.postgres.embedded.LiquibasePreparer;
+import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
+import com.opentable.db.postgres.junit.PreparedDbRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.skife.jdbi.v2.DBI;
+import org.whispersystems.textsecuregcm.entities.PreKey;
+import org.whispersystems.textsecuregcm.storage.KeyRecord;
+import org.whispersystems.textsecuregcm.storage.Keys;
+
+import javax.sql.DataSource;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.LinkedList;
+import java.util.List;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
+public class KeysTest {
+
+ @Rule
+ public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
+
+ @Test
+ public void testPopulateKeys() throws SQLException {
+ DataSource dataSource = db.getTestDatabase();
+ DBI dbi = new DBI(dataSource);
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List deviceOnePreKeys = new LinkedList<>();
+ List deviceTwoPreKeys = new LinkedList<>();
+
+ List oldAnotherDeviceOnePrKeys = new LinkedList<>();
+ List anotherDeviceOnePreKeys = new LinkedList<>();
+ List anotherDeviceTwoPreKeys = new LinkedList<>();
+
+ for (int i=1;i<=100;i++) {
+ deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
+ deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
+ }
+
+ for (int i=1;i<=100;i++) {
+ oldAnotherDeviceOnePrKeys.add(new PreKey(i, "OldPublicKey" + i));
+ anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
+ anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
+ }
+
+ keys.store("+14152222222", 1, deviceOnePreKeys);
+ keys.store("+14152222222", 2, deviceTwoPreKeys);
+
+ keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys);
+ keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
+ keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
+
+ PreparedStatement statement = dataSource.getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id");
+ verifyStoredState(statement, "+14152222222", 1);
+ verifyStoredState(statement, "+14152222222", 2);
+ verifyStoredState(statement, "+14151111111", 1);
+ verifyStoredState(statement, "+14151111111", 2);
+ }
+
+ @Test
+ public void testKeyCount() throws SQLException {
+ DataSource dataSource = db.getTestDatabase();
+ DBI dbi = new DBI(dataSource);
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List deviceOnePreKeys = new LinkedList<>();
+
+ for (int i=1;i<=100;i++) {
+ deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
+ }
+
+
+ keys.store("+14152222222", 1, deviceOnePreKeys);
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
+ }
+
+ @Test
+ public void testGetForDevice() {
+ DataSource dataSource = db.getTestDatabase();
+ DBI dbi = new DBI(dataSource);
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List deviceOnePreKeys = new LinkedList<>();
+ List deviceTwoPreKeys = new LinkedList<>();
+
+ List anotherDeviceOnePreKeys = new LinkedList<>();
+ List anotherDeviceTwoPreKeys = new LinkedList<>();
+
+ for (int i=1;i<=100;i++) {
+ deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
+ deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
+ }
+
+ for (int i=1;i<=100;i++) {
+ anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
+ anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
+ }
+
+ keys.store("+14152222222", 1, deviceOnePreKeys);
+ keys.store("+14152222222", 2, deviceTwoPreKeys);
+
+ keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
+ keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
+
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
+ List records = keys.get("+14152222222", 1);
+
+ assertThat(records.size()).isEqualTo(1);
+ assertThat(records.get(0).getKeyId()).isEqualTo(1);
+ assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1");
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
+ assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
+ assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
+
+ records = keys.get("+14152222222", 1);
+
+ assertThat(records.size()).isEqualTo(1);
+ assertThat(records.get(0).getKeyId()).isEqualTo(2);
+ assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2");
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
+ assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
+ assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
+
+ records = keys.get("+14152222222", 2);
+
+ assertThat(records.size()).isEqualTo(1);
+ assertThat(records.get(0).getKeyId()).isEqualTo(1);
+ assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1");
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
+ assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100);
+ assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100);
+ }
+
+ @Test
+ public void testGetForAllDevices() {
+ DataSource dataSource = db.getTestDatabase();
+ DBI dbi = new DBI(dataSource);
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List deviceOnePreKeys = new LinkedList<>();
+ List deviceTwoPreKeys = new LinkedList<>();
+
+ List anotherDeviceOnePreKeys = new LinkedList<>();
+ List anotherDeviceTwoPreKeys = new LinkedList<>();
+ List anotherDeviceThreePreKeys = new LinkedList<>();
+
+ for (int i=1;i<=100;i++) {
+ deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
+ deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
+ }
+
+ for (int i=1;i<=100;i++) {
+ anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i));
+ anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i));
+ anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i));
+ }
+
+ keys.store("+14152222222", 1, deviceOnePreKeys);
+ keys.store("+14152222222", 2, deviceTwoPreKeys);
+
+ keys.store("+14151111111", 1, anotherDeviceOnePreKeys);
+ keys.store("+14151111111", 2, anotherDeviceTwoPreKeys);
+ keys.store("+14151111111", 3, anotherDeviceThreePreKeys);
+
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
+
+ List records = keys.get("+14152222222");
+
+ assertThat(records.size()).isEqualTo(2);
+ assertThat(records.get(0).getKeyId()).isEqualTo(1);
+ assertThat(records.get(1).getKeyId()).isEqualTo(1);
+
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue();
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue();
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99);
+
+ records = keys.get("+14152222222");
+
+ assertThat(records.size()).isEqualTo(2);
+ assertThat(records.get(0).getKeyId()).isEqualTo(2);
+ assertThat(records.get(1).getKeyId()).isEqualTo(2);
+
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue();
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue();
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98);
+
+
+ records = keys.get("+14151111111");
+
+ assertThat(records.size()).isEqualTo(3);
+ assertThat(records.get(0).getKeyId()).isEqualTo(1);
+ assertThat(records.get(1).getKeyId()).isEqualTo(1);
+ assertThat(records.get(2).getKeyId()).isEqualTo(1);
+
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device1PublicKey1"))).isTrue();
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue();
+ assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue();
+
+ assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99);
+ assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99);
+ assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99);
+ }
+
+ @Test
+ public void testGetForAllDevicesParallel() throws InterruptedException {
+ DataSource dataSource = db.getTestDatabase();
+ DBI dbi = new DBI(dataSource);
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List deviceOnePreKeys = new LinkedList<>();
+ List deviceTwoPreKeys = new LinkedList<>();
+
+ for (int i=1;i<=100;i++) {
+ deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i));
+ deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i));
+ }
+
+ keys.store("+14152222222", 1, deviceOnePreKeys);
+ keys.store("+14152222222", 2, deviceTwoPreKeys);
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100);
+ assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100);
+
+ List threads = new LinkedList<>();
+
+ for (int i=0;i<50;i++) {
+ Thread thread = new Thread(() -> {
+ for (int j=0;j<10;j++) {
+ try {
+ List results = keys.get("+14152222222");
+ assertThat(results.size()).isEqualTo(2);
+ return;
+ } catch (Exception e) {
+ System.err.println(e.getMessage());
+ }
+ }
+ throw new AssertionError();
+ });
+ thread.start();
+ threads.add(thread);
+ }
+
+ for (Thread thread : threads) {
+ thread.join();
+ }
+
+ assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50);
+ assertThat(keys.getCount("+14152222222",2)).isEqualTo(50);
+ }
+
+
+ @Test
+ public void testEmptyKeyGet() {
+ DBI dbi = new DBI(db.getTestDatabase());
+ Keys keys = dbi.onDemand(Keys.class);
+
+ List records = keys.get("+14152222222");
+
+ assertThat(records.isEmpty()).isTrue();
+ }
+
+
+ private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException {
+ statement.setString(1, number);
+ statement.setInt(2, deviceId);
+
+ ResultSet resultSet = statement.executeQuery();
+ int rowCount = 1;
+
+ while (resultSet.next()) {
+ long keyId = resultSet.getLong("key_id");
+ String publicKey = resultSet.getString("public_key");
+
+
+ assertThat(keyId).isEqualTo(rowCount);
+ assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount);
+
+ rowCount++;
+ }
+
+ resultSet.close();
+
+ assertThat(rowCount).isEqualTo(101);
+
+ }
+
+}