Make `KeysManager` storage/retrieval operations asynchronous

This commit is contained in:
Jon Chambers 2023-06-26 11:17:02 -04:00 committed by Jon Chambers
parent 5847300290
commit f709b00be3
10 changed files with 204 additions and 165 deletions

View File

@ -17,6 +17,7 @@ import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -340,12 +341,16 @@ public class DeviceController {
keys.delete(a.getUuid(), device.getId()); keys.delete(a.getUuid(), device.getId());
keys.delete(a.getPhoneNumberIdentifier(), device.getId()); keys.delete(a.getPhoneNumberIdentifier(), device.getId());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())); keys.storeEcSignedPreKeys(a.getUuid(),
keys.storePqLastResort(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())); Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())); keys.storePqLastResort(a.getUuid(),
keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())); Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())),
}); keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
.join());
a.addDevice(device); a.addDevice(device);
}); });

View File

@ -25,6 +25,7 @@ import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -95,10 +96,13 @@ public class KeysController {
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") final Optional<String> identityType) {
int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); final CompletableFuture<Integer> ecCountFuture =
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(ecCount, pqCount); final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(ecCountFuture.join(), pqCountFuture.join());
} }
@Timed @Timed
@ -181,8 +185,9 @@ public class KeysController {
} }
keys.store( keys.store(
getIdentifier(account, identityType), device.getId(), getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey()); preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey())
.join();
} }
@Timed @Timed
@ -243,8 +248,8 @@ public class KeysController {
for (Device device : devices) { for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).join().orElse(null);
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).join().orElse(null) : null;
compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(identifier, device.getId())); keys.getEcSignedPreKey(identifier, device.getId()));

View File

@ -23,6 +23,7 @@ import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -176,10 +177,16 @@ public class RegistrationController {
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId -> registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId())); device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
keysManager.storeEcSignedPreKeys(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())); CompletableFuture.allOf(
keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())); keysManager.storeEcSignedPreKeys(a.getUuid(),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())); Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())); keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())))
.join();
}); });
} }

View File

@ -332,7 +332,7 @@ public class AccountsManager {
if (pniPqLastResortPreKeys != null) { if (pniPqLastResortPreKeys != null) {
keysManager.storePqLastResort( keysManager.storePqLastResort(
phoneNumberIdentifier, phoneNumberIdentifier,
keysManager.getPqEnabledDevices(uuid).stream().collect( keysManager.getPqEnabledDevices(uuid).join().stream().collect(
Collectors.toMap( Collectors.toMap(
Function.identity(), Function.identity(),
pniPqLastResortPreKeys::get))); pniPqLastResortPreKeys::get)));
@ -367,7 +367,7 @@ public class AccountsManager {
final UUID pni = account.getPhoneNumberIdentifier(); final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni); final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni).join();
keysManager.delete(pni); keysManager.delete(pni);
keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys); keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys);
if (pniPqLastResortPreKeys != null) { if (pniPqLastResortPreKeys != null) {

View File

@ -42,11 +42,11 @@ public class KeysManager {
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
} }
public void store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) { public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) {
store(identifier, deviceId, keys, null, null, null); return store(identifier, deviceId, keys, null, null, null);
} }
public void store( public CompletableFuture<Void> store(
final UUID identifier, final long deviceId, final UUID identifier, final long deviceId,
@Nullable final List<ECPreKey> ecKeys, @Nullable final List<ECPreKey> ecKeys,
@Nullable final List<KEMSignedPreKey> pqKeys, @Nullable final List<KEMSignedPreKey> pqKeys,
@ -71,12 +71,14 @@ public class KeysManager {
storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey)); storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey));
} }
CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join(); return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
} }
public void storeEcSignedPreKeys(final UUID identifier, final Map<Long, ECSignedPreKey> keys) { public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Long, ECSignedPreKey> keys) {
if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
ecSignedPreKeys.store(identifier, keys).join(); return ecSignedPreKeys.store(identifier, keys);
} else {
return CompletableFuture.completedFuture(null);
} }
} }
@ -84,40 +86,40 @@ public class KeysManager {
return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey); return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey);
} }
public void storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) { public CompletableFuture<Void> storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) {
pqLastResortKeys.store(identifier, keys).join(); return pqLastResortKeys.store(identifier, keys);
} }
public Optional<ECPreKey> takeEC(final UUID identifier, final long deviceId) { public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId).join(); return ecPreKeys.take(identifier, deviceId);
} }
public Optional<KEMSignedPreKey> takePQ(final UUID identifier, final long deviceId) { public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final long deviceId) {
return pqPreKeys.take(identifier, deviceId) return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join(); .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId)));
} }
@VisibleForTesting @VisibleForTesting
Optional<KEMSignedPreKey> getLastResort(final UUID identifier, final long deviceId) { CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId).join(); return pqLastResortKeys.find(identifier, deviceId);
} }
public CompletableFuture<Optional<ECSignedPreKey>> getEcSignedPreKey(final UUID identifier, final long deviceId) { public CompletableFuture<Optional<ECSignedPreKey>> getEcSignedPreKey(final UUID identifier, final long deviceId) {
return ecSignedPreKeys.find(identifier, deviceId); return ecSignedPreKeys.find(identifier, deviceId);
} }
public List<Long> getPqEnabledDevices(final UUID identifier) { public CompletableFuture<List<Long>> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block(); return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
} }
public int getEcCount(final UUID identifier, final long deviceId) { public CompletableFuture<Integer> getEcCount(final UUID identifier, final long deviceId) {
return ecPreKeys.getCount(identifier, deviceId).join(); return ecPreKeys.getCount(identifier, deviceId);
} }
public int getPqCount(final UUID identifier, final long deviceId) { public CompletableFuture<Integer> getPqCount(final UUID identifier, final long deviceId) {
return pqPreKeys.getCount(identifier, deviceId).join(); return pqPreKeys.getCount(identifier, deviceId);
} }
public void delete(final UUID accountUuid) { public void delete(final UUID accountUuid) {

View File

@ -645,6 +645,9 @@ class RegistrationControllerTest {
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS)))); SESSION_EXPIRATION_SECONDS))));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID();
final Device device = mock(Device.class); final Device device = mock(Device.class);
@ -664,6 +667,8 @@ class RegistrationControllerTest {
return invocation.getArgument(0); return invocation.getArgument(0);
}); });
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()

View File

@ -33,6 +33,7 @@ import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -732,7 +733,7 @@ class AccountsManagerTest {
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L)); when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
@ -783,6 +784,8 @@ class AccountsManagerTest {
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList()));
final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds);
// non-PNI stuff should not change // non-PNI stuff should not change
@ -825,7 +828,7 @@ class AccountsManagerTest {
UUID oldUuid = account.getUuid(); UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier(); UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));

View File

@ -66,204 +66,208 @@ class KeysManagerTest {
@Test @Test
void testStore() { void testStore() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero"); "Initial pre-key count for an account should be zero");
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero"); "Initial pre-key count for an account should be zero");
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent(),
"Initial last-resort pre-key for an account should be missing"); "Initial last-resort pre-key for an account should be missing");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Repeatedly storing same key should have no effect"); "Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ prekeys should have no effect on EC prekeys"); "Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys"); "Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId()); assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device"); "Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new EC prekeys should have no effect on PQ prekeys"); "Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device"); "Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device"); "Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)), List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)); List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)).join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device"); "Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device"); "Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(), assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device"); "Uploading new last-resort key should overwrite prior last-resort key for the account/device");
} }
@Test @Test
void testTakeAccountAndDeviceId() { void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join());
final ECPreKey preKey = generateTestPreKey(1); final ECPreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)), null, null, null).join();
final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join();
assertEquals(Optional.of(preKey), takenKey); assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
} }
@Test @Test
void testTakePQ() { void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join());
final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1); final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1);
final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2); final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2);
final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001); final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast).join();
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
} }
@Test @Test
void testGetCount() { void testGetCount() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
} }
@Test @Test
void testDeleteByAccount() { void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)), List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5), generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6)); generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(7)), List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)), List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9), generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10)); generateTestKEMSignedPreKey(10))
.join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
keysManager.delete(ACCOUNT_UUID); keysManager.delete(ACCOUNT_UUID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
} }
@Test @Test
void testDeleteByAccountAndDevice() { void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1),generateTestPreKey(2)), List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5), generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6)); generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(7)), List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)), List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9), generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10)); generateTestKEMSignedPreKey(10))
.join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID); keysManager.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
} }
@Test @Test
void testStorePqLastResort() { void testStorePqLastResort() {
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort( keysManager.storePqLastResort(
ACCOUNT_UUID, ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId()); assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId()); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent());
keysManager.storePqLastResort( keysManager.storePqLastResort(
ACCOUNT_UUID, ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates"); assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones"); assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone"); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones"); assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
} }
@Test @Test
void testGetPqEnabledDevices() { void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join();
assertIterableEquals( assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2), Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID))); Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
} }
@Test @Test
@ -273,14 +277,15 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1)), List.of(generateTestPreKey(1)),
List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)), List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)),
KeysHelper.signedECPreKey(3, identityKeyPair), KeysHelper.signedECPreKey(3, identityKeyPair),
KeysHelper.signedKEMPreKey(4, identityKeyPair)); KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
} }

View File

@ -27,6 +27,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
@ -282,6 +283,9 @@ class DeviceControllerTest {
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901", final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null), new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));

View File

@ -222,15 +222,16 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI)); when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@ -334,7 +335,7 @@ class KeysControllerTest {
@Test @Test
void validSingleRequestPqTestNoPqKeysV2() { void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty()); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@ -520,10 +521,10 @@ class KeysControllerTest {
@Test @Test
void validMultiRequestTestV2() { void validMultiRequestTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
PreKeyResponse results = resources.getJerseyTest() PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@ -575,13 +576,15 @@ class KeysControllerTest {
@Test @Test
void validMultiRequestPqTestV2() { void validMultiRequestPqTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2)); when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3)); when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty()); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
PreKeyResponse results = resources.getJerseyTest() PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))