diff --git a/pom.xml b/pom.xml index 2cdd2f9fd..071e3c14d 100644 --- a/pom.xml +++ b/pom.xml @@ -166,6 +166,12 @@ test + + com.opentable.components + otj-pg-embedded + 0.13.1 + test + diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 421cb7643..fb7e238d4 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2014 Open Whisper Systems * * This program is free software: you can redistribute it and/or modify @@ -135,19 +135,17 @@ public class KeysController { rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId); } - Optional> targetKeys = getLocalKeys(target.get(), deviceId); - List devices = new LinkedList<>(); + List targetKeys = getLocalKeys(target.get(), deviceId); + List devices = new LinkedList<>(); for (Device device : target.get().getDevices()) { if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) { SignedPreKey signedPreKey = device.getSignedPreKey(); PreKey preKey = null; - if (targetKeys.isPresent()) { - for (KeyRecord keyRecord : targetKeys.get()) { - if (!keyRecord.isLastResort() && keyRecord.getDeviceId() == device.getId()) { + for (KeyRecord keyRecord : targetKeys) { + if (keyRecord.getDeviceId() == device.getId()) { preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey()); - } } } @@ -189,7 +187,7 @@ public class KeysController { else return Optional.empty(); } - private Optional> getLocalKeys(Account destination, String deviceIdSelector) { + private List getLocalKeys(Account destination, String deviceIdSelector) { try { if (deviceIdSelector.equals("*")) { return keys.get(destination.getNumber()); diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java b/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java index 6f5d0bc47..7b1222158 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java @@ -7,17 +7,13 @@ public class KeyRecord { private long deviceId; private long keyId; private String publicKey; - private boolean lastResort; - public KeyRecord(long id, String number, long deviceId, long keyId, - String publicKey, boolean lastResort) - { + public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) { this.id = id; this.number = number; this.deviceId = deviceId; this.keyId = keyId; this.publicKey = publicKey; - this.lastResort = lastResort; } public long getId() { @@ -40,7 +36,4 @@ public class KeyRecord { return publicKey; } - public boolean isLastResort() { - return lastResort; - } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index 2082cf716..07767c251 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2013 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -38,97 +38,69 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.sql.ResultSet; import java.sql.SQLException; -import java.util.LinkedList; import java.util.List; -import java.util.Optional; +import java.util.stream.Collectors; public abstract class Keys { @SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id") - abstract void removeKeys(@Bind("number") String number, @Bind("device_id") long deviceId); + abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId); - @SqlUpdate("DELETE FROM keys WHERE id = :id") - abstract void removeKey(@Bind("id") long id); + @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key) VALUES (:number, :device_id, :key_id, :public_key)") + abstract void append(@KeyRecordBinder List preKeys); - @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " + - "(:number, :device_id, :key_id, :public_key, :last_resort)") - abstract void append(@PreKeyBinder List preKeys); + @SqlQuery("DELETE FROM keys WHERE id IN (SELECT id FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC LIMIT 1) RETURNING *") + @Mapper(KeyRecordMapper.class) + abstract List getInternal(@Bind("number") String number, @Bind("device_id") long deviceId); - @SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE") - @Mapper(PreKeyMapper.class) - abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId); - - @SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC") - @Mapper(PreKeyMapper.class) - abstract List retrieveFirst(@Bind("number") String number); + @SqlQuery("DELETE FROM keys WHERE id IN (SELECT DISTINCT ON (number, device_id) id FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC) RETURNING *") + @Mapper(KeyRecordMapper.class) + abstract List getInternal(@Bind("number") String number); @SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId); + // Apparently transaction annotations don't work on the annotated query methods + @Transaction(TransactionIsolationLevel.SERIALIZABLE) + public List get(String number) { + return getInternal(number); + } + + @Transaction(TransactionIsolationLevel.SERIALIZABLE) + public List get(String number, long deviceId) { + return getInternal(number, deviceId); + } + @Transaction(TransactionIsolationLevel.SERIALIZABLE) public void store(String number, long deviceId, List keys) { - List records = new LinkedList<>(); + List records = keys.stream() + .map(key -> new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey())) + .collect(Collectors.toList()); - for (PreKey key : keys) { - records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false)); - } - - removeKeys(number, deviceId); + remove(number, deviceId); append(records); } - @Transaction(TransactionIsolationLevel.SERIALIZABLE) - public Optional> get(String number, long deviceId) { - final KeyRecord record = retrieveFirst(number, deviceId); - - if (record != null && !record.isLastResort()) { - removeKey(record.getId()); - } else if (record == null) { - return Optional.empty(); - } - - List results = new LinkedList<>(); - results.add(record); - - return Optional.of(results); - } - - @Transaction(TransactionIsolationLevel.SERIALIZABLE) - public Optional> get(String number) { - List preKeys = retrieveFirst(number); - - if (preKeys != null) { - for (KeyRecord preKey : preKeys) { - if (!preKey.isLastResort()) { - removeKey(preKey.getId()); - } - } - } - - if (preKeys != null) return Optional.of(preKeys); - else return Optional.empty(); - } @SqlUpdate("VACUUM keys") public abstract void vacuum(); - @BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class) + @BindingAnnotation(KeyRecordBinder.PreKeyBinderFactory.class) @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.PARAMETER}) - public @interface PreKeyBinder { + public @interface KeyRecordBinder { public static class PreKeyBinderFactory implements BinderFactory { @Override public Binder build(Annotation annotation) { - return new Binder() { + return new Binder() { @Override - public void bind(SQLStatement sql, PreKeyBinder accountBinder, KeyRecord record) + public void bind(SQLStatement sql, KeyRecordBinder keyRecordBinder, KeyRecord record) { sql.bind("id", record.getId()); sql.bind("number", record.getNumber()); sql.bind("device_id", record.getDeviceId()); sql.bind("key_id", record.getKeyId()); sql.bind("public_key", record.getPublicKey()); - sql.bind("last_resort", record.isLastResort() ? 1 : 0); } }; } @@ -136,14 +108,14 @@ public abstract class Keys { } - public static class PreKeyMapper implements ResultSetMapper { + public static class KeyRecordMapper implements ResultSetMapper { @Override public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext) throws SQLException { return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"), resultSet.getLong("device_id"), resultSet.getLong("key_id"), - resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1); + resultSet.getString("public_key")); } } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java index 7a980c008..e4a07c175 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java @@ -45,15 +45,16 @@ public class KeyControllerTest { private static int SAMPLE_REGISTRATION_ID2 = 1002; private static int SAMPLE_REGISTRATION_ID4 = 1555; - private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false); - private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false ); - private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false ); - private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false ); + private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1"); + private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3"); + private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5"); + private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6"); - private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11"); - private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22"); - private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33"); + private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" ); + private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey( 2222, "foobar", "sig22" ); + private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" ); + private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid"); private final Keys keys = mock(Keys.class ); private final AccountsManager accounts = mock(AccountsManager.class); @@ -120,20 +121,20 @@ public class KeyControllerTest { List singleDevice = new LinkedList<>(); singleDevice.add(SAMPLE_KEY); - when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice)); + when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(singleDevice); - when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.>empty()); + when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(new LinkedList<>()); List multiDevice = new LinkedList<>(); multiDevice.add(SAMPLE_KEY); multiDevice.add(SAMPLE_KEY2); multiDevice.add(SAMPLE_KEY3); multiDevice.add(SAMPLE_KEY4); - when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice)); + when(keys.get(EXISTS_NUMBER)).thenReturn(multiDevice); when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); - when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid")); + when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); } @@ -146,7 +147,7 @@ public class KeyControllerTest { AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .get(PreKeyCount.class); - assertThat(result.getCount() == 4); + assertThat(result.getCount()).isEqualTo(4); verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L)); } @@ -159,7 +160,9 @@ public class KeyControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .get(SignedPreKey.class); - assertThat(result.equals(SAMPLE_SIGNED_KEY)); + assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature()); + assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId()); + assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey()); } @Test @@ -171,7 +174,7 @@ public class KeyControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus() == 204); + assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT)); @@ -332,10 +335,9 @@ public class KeyControllerTest { @Test public void putKeysTestV2() throws Exception { - final PreKey preKey = new PreKey(31337, "foobar"); - final PreKey lastResortKey = new PreKey(31339, "barbar"); - final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig"); - final String identityKey = "barbar"; + final PreKey preKey = new PreKey(31337, "foobar"); + final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig"); + final String identityKey = "barbar"; List preKeys = new LinkedList() {{ add(preKey); @@ -356,9 +358,9 @@ public class KeyControllerTest { verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); - assertThat(capturedList.size() == 1); - assertThat(capturedList.get(0).getKeyId() == 31337); - assertThat(capturedList.get(0).getPublicKey().equals("foobar")); + assertThat(capturedList.size()).isEqualTo(1); + assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); + assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar"); verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar")); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java new file mode 100644 index 000000000..589c79cec --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/storage/KeysTest.java @@ -0,0 +1,303 @@ +package org.whispersystems.textsecuregcm.tests.storage; + +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit.EmbeddedPostgresRules; +import com.opentable.db.postgres.junit.PreparedDbRule; +import org.junit.Rule; +import org.junit.Test; +import org.skife.jdbi.v2.DBI; +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.storage.KeyRecord; +import org.whispersystems.textsecuregcm.storage.Keys; + +import javax.sql.DataSource; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.LinkedList; +import java.util.List; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class KeysTest { + + @Rule + public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); + + @Test + public void testPopulateKeys() throws SQLException { + DataSource dataSource = db.getTestDatabase(); + DBI dbi = new DBI(dataSource); + Keys keys = dbi.onDemand(Keys.class); + + List deviceOnePreKeys = new LinkedList<>(); + List deviceTwoPreKeys = new LinkedList<>(); + + List oldAnotherDeviceOnePrKeys = new LinkedList<>(); + List anotherDeviceOnePreKeys = new LinkedList<>(); + List anotherDeviceTwoPreKeys = new LinkedList<>(); + + for (int i=1;i<=100;i++) { + deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i)); + deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); + } + + for (int i=1;i<=100;i++) { + oldAnotherDeviceOnePrKeys.add(new PreKey(i, "OldPublicKey" + i)); + anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i)); + anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); + } + + keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store("+14152222222", 2, deviceTwoPreKeys); + + keys.store("+14151111111", 1, oldAnotherDeviceOnePrKeys); + keys.store("+14151111111", 1, anotherDeviceOnePreKeys); + keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); + + PreparedStatement statement = dataSource.getConnection().prepareStatement("SELECT * FROM keys WHERE number = ? AND device_id = ? ORDER BY key_id"); + verifyStoredState(statement, "+14152222222", 1); + verifyStoredState(statement, "+14152222222", 2); + verifyStoredState(statement, "+14151111111", 1); + verifyStoredState(statement, "+14151111111", 2); + } + + @Test + public void testKeyCount() throws SQLException { + DataSource dataSource = db.getTestDatabase(); + DBI dbi = new DBI(dataSource); + Keys keys = dbi.onDemand(Keys.class); + + List deviceOnePreKeys = new LinkedList<>(); + + for (int i=1;i<=100;i++) { + deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i)); + } + + + keys.store("+14152222222", 1, deviceOnePreKeys); + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); + } + + @Test + public void testGetForDevice() { + DataSource dataSource = db.getTestDatabase(); + DBI dbi = new DBI(dataSource); + Keys keys = dbi.onDemand(Keys.class); + + List deviceOnePreKeys = new LinkedList<>(); + List deviceTwoPreKeys = new LinkedList<>(); + + List anotherDeviceOnePreKeys = new LinkedList<>(); + List anotherDeviceTwoPreKeys = new LinkedList<>(); + + for (int i=1;i<=100;i++) { + deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i)); + deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); + } + + for (int i=1;i<=100;i++) { + anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i)); + anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); + } + + keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store("+14152222222", 2, deviceTwoPreKeys); + + keys.store("+14151111111", 1, anotherDeviceOnePreKeys); + keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); + + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); + List records = keys.get("+14152222222", 1); + + assertThat(records.size()).isEqualTo(1); + assertThat(records.get(0).getKeyId()).isEqualTo(1); + assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey1"); + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); + assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + + records = keys.get("+14152222222", 1); + + assertThat(records.size()).isEqualTo(1); + assertThat(records.get(0).getKeyId()).isEqualTo(2); + assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device1PublicKey2"); + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); + assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + + records = keys.get("+14152222222", 2); + + assertThat(records.size()).isEqualTo(1); + assertThat(records.get(0).getKeyId()).isEqualTo(1); + assertThat(records.get(0).getPublicKey()).isEqualTo("+14152222222Device2PublicKey1"); + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); + assertThat(keys.getCount("+14151111111", 1)).isEqualTo(100); + assertThat(keys.getCount("+14151111111", 2)).isEqualTo(100); + } + + @Test + public void testGetForAllDevices() { + DataSource dataSource = db.getTestDatabase(); + DBI dbi = new DBI(dataSource); + Keys keys = dbi.onDemand(Keys.class); + + List deviceOnePreKeys = new LinkedList<>(); + List deviceTwoPreKeys = new LinkedList<>(); + + List anotherDeviceOnePreKeys = new LinkedList<>(); + List anotherDeviceTwoPreKeys = new LinkedList<>(); + List anotherDeviceThreePreKeys = new LinkedList<>(); + + for (int i=1;i<=100;i++) { + deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i)); + deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); + } + + for (int i=1;i<=100;i++) { + anotherDeviceOnePreKeys.add(new PreKey(i, "+14151111111Device1PublicKey" + i)); + anotherDeviceTwoPreKeys.add(new PreKey(i, "+14151111111Device2PublicKey" + i)); + anotherDeviceThreePreKeys.add(new PreKey(i, "+14151111111Device3PublicKey" + i)); + } + + keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store("+14152222222", 2, deviceTwoPreKeys); + + keys.store("+14151111111", 1, anotherDeviceOnePreKeys); + keys.store("+14151111111", 2, anotherDeviceTwoPreKeys); + keys.store("+14151111111", 3, anotherDeviceThreePreKeys); + + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + + List records = keys.get("+14152222222"); + + assertThat(records.size()).isEqualTo(2); + assertThat(records.get(0).getKeyId()).isEqualTo(1); + assertThat(records.get(1).getKeyId()).isEqualTo(1); + + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey1"))).isTrue(); + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey1"))).isTrue(); + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(99); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(99); + + records = keys.get("+14152222222"); + + assertThat(records.size()).isEqualTo(2); + assertThat(records.get(0).getKeyId()).isEqualTo(2); + assertThat(records.get(1).getKeyId()).isEqualTo(2); + + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device1PublicKey2"))).isTrue(); + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14152222222Device2PublicKey2"))).isTrue(); + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(98); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(98); + + + records = keys.get("+14151111111"); + + assertThat(records.size()).isEqualTo(3); + assertThat(records.get(0).getKeyId()).isEqualTo(1); + assertThat(records.get(1).getKeyId()).isEqualTo(1); + assertThat(records.get(2).getKeyId()).isEqualTo(1); + + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device1PublicKey1"))).isTrue(); + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device2PublicKey1"))).isTrue(); + assertThat(records.stream().anyMatch(record -> record.getPublicKey().equals("+14151111111Device3PublicKey1"))).isTrue(); + + assertThat(keys.getCount("+14151111111", 1)).isEqualTo(99); + assertThat(keys.getCount("+14151111111", 2)).isEqualTo(99); + assertThat(keys.getCount("+14151111111", 3)).isEqualTo(99); + } + + @Test + public void testGetForAllDevicesParallel() throws InterruptedException { + DataSource dataSource = db.getTestDatabase(); + DBI dbi = new DBI(dataSource); + Keys keys = dbi.onDemand(Keys.class); + + List deviceOnePreKeys = new LinkedList<>(); + List deviceTwoPreKeys = new LinkedList<>(); + + for (int i=1;i<=100;i++) { + deviceOnePreKeys.add(new PreKey(i, "+14152222222Device1PublicKey" + i)); + deviceTwoPreKeys.add(new PreKey(i, "+14152222222Device2PublicKey" + i)); + } + + keys.store("+14152222222", 1, deviceOnePreKeys); + keys.store("+14152222222", 2, deviceTwoPreKeys); + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(100); + assertThat(keys.getCount("+14152222222", 2)).isEqualTo(100); + + List threads = new LinkedList<>(); + + for (int i=0;i<50;i++) { + Thread thread = new Thread(() -> { + for (int j=0;j<10;j++) { + try { + List results = keys.get("+14152222222"); + assertThat(results.size()).isEqualTo(2); + return; + } catch (Exception e) { + System.err.println(e.getMessage()); + } + } + throw new AssertionError(); + }); + thread.start(); + threads.add(thread); + } + + for (Thread thread : threads) { + thread.join(); + } + + assertThat(keys.getCount("+14152222222", 1)).isEqualTo(50); + assertThat(keys.getCount("+14152222222",2)).isEqualTo(50); + } + + + @Test + public void testEmptyKeyGet() { + DBI dbi = new DBI(db.getTestDatabase()); + Keys keys = dbi.onDemand(Keys.class); + + List records = keys.get("+14152222222"); + + assertThat(records.isEmpty()).isTrue(); + } + + + private void verifyStoredState(PreparedStatement statement, String number, int deviceId) throws SQLException { + statement.setString(1, number); + statement.setInt(2, deviceId); + + ResultSet resultSet = statement.executeQuery(); + int rowCount = 1; + + while (resultSet.next()) { + long keyId = resultSet.getLong("key_id"); + String publicKey = resultSet.getString("public_key"); + + + assertThat(keyId).isEqualTo(rowCount); + assertThat(publicKey).isEqualTo(number + "Device" + deviceId + "PublicKey" + rowCount); + + rowCount++; + } + + resultSet.close(); + + assertThat(rowCount).isEqualTo(101); + + } + +}