diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 07250cbd6..dde0b9ba1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -17,6 +17,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; @@ -340,12 +341,16 @@ public class DeviceController { keys.delete(a.getUuid(), device.getId()); keys.delete(a.getPhoneNumberIdentifier(), device.getId()); - maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { - keys.storeEcSignedPreKeys(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())); - keys.storePqLastResort(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())); - keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())); - keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())); - }); + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( + keys.storeEcSignedPreKeys(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())), + keys.storePqLastResort(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())), + keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())), + keys.storePqLastResort(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get()))) + .join()); a.addDevice(device); }); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 7288ba1bf..1bd650ff3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; @@ -95,10 +96,13 @@ public class KeysController { public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, @QueryParam("identity") final Optional identityType) { - int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); - int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + final CompletableFuture ecCountFuture = + keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); - return new PreKeyCount(ecCount, pqCount); + final CompletableFuture pqCountFuture = + keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + + return new PreKeyCount(ecCountFuture.join(), pqCountFuture.join()); } @Timed @@ -181,8 +185,9 @@ public class KeysController { } keys.store( - getIdentifier(account, identityType), device.getId(), - preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey()); + getIdentifier(account, identityType), device.getId(), + preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey()) + .join(); } @Timed @@ -243,8 +248,8 @@ public class KeysController { for (Device device : devices) { UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); - ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); - KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; + ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).join().orElse(null); + KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).join().orElse(null) : null; compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), keys.getEcSignedPreKey(identifier, device.getId())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 9db861b4c..4d600d02f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -23,6 +23,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; @@ -176,10 +177,16 @@ public class RegistrationController { registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId -> device.setGcmId(gcmRegistrationId.gcmRegistrationId())); - keysManager.storeEcSignedPreKeys(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())); - keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())); - keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())); - keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())); + CompletableFuture.allOf( + keysManager.storeEcSignedPreKeys(a.getUuid(), + Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())), + keysManager.storePqLastResort(a.getUuid(), + Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())), + keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), + Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())), + keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), + Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()))) + .join(); }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 586309ee5..a6ef4e2a7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -332,7 +332,7 @@ public class AccountsManager { if (pniPqLastResortPreKeys != null) { keysManager.storePqLastResort( phoneNumberIdentifier, - keysManager.getPqEnabledDevices(uuid).stream().collect( + keysManager.getPqEnabledDevices(uuid).join().stream().collect( Collectors.toMap( Function.identity(), pniPqLastResortPreKeys::get))); @@ -367,7 +367,7 @@ public class AccountsManager { final UUID pni = account.getPhoneNumberIdentifier(); final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); - final List pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni); + final List pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni).join(); keysManager.delete(pni); keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys); if (pniPqLastResortPreKeys != null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 419281656..c79020a49 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -42,11 +42,11 @@ public class KeysManager { this.dynamicConfigurationManager = dynamicConfigurationManager; } - public void store(final UUID identifier, final long deviceId, final List keys) { - store(identifier, deviceId, keys, null, null, null); + public CompletableFuture store(final UUID identifier, final long deviceId, final List keys) { + return store(identifier, deviceId, keys, null, null, null); } - public void store( + public CompletableFuture store( final UUID identifier, final long deviceId, @Nullable final List ecKeys, @Nullable final List pqKeys, @@ -71,12 +71,14 @@ public class KeysManager { storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey)); } - CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join(); + return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])); } - public void storeEcSignedPreKeys(final UUID identifier, final Map keys) { + public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { - ecSignedPreKeys.store(identifier, keys).join(); + return ecSignedPreKeys.store(identifier, keys); + } else { + return CompletableFuture.completedFuture(null); } } @@ -84,40 +86,40 @@ public class KeysManager { return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey); } - public void storePqLastResort(final UUID identifier, final Map keys) { - pqLastResortKeys.store(identifier, keys).join(); + public CompletableFuture storePqLastResort(final UUID identifier, final Map keys) { + return pqLastResortKeys.store(identifier, keys); } - public Optional takeEC(final UUID identifier, final long deviceId) { - return ecPreKeys.take(identifier, deviceId).join(); + public CompletableFuture> takeEC(final UUID identifier, final long deviceId) { + return ecPreKeys.take(identifier, deviceId); } - public Optional takePQ(final UUID identifier, final long deviceId) { + public CompletableFuture> takePQ(final UUID identifier, final long deviceId) { return pqPreKeys.take(identifier, deviceId) .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) - .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join(); + .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))); } @VisibleForTesting - Optional getLastResort(final UUID identifier, final long deviceId) { - return pqLastResortKeys.find(identifier, deviceId).join(); + CompletableFuture> getLastResort(final UUID identifier, final long deviceId) { + return pqLastResortKeys.find(identifier, deviceId); } public CompletableFuture> getEcSignedPreKey(final UUID identifier, final long deviceId) { return ecSignedPreKeys.find(identifier, deviceId); } - public List getPqEnabledDevices(final UUID identifier) { - return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block(); + public CompletableFuture> getPqEnabledDevices(final UUID identifier) { + return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture(); } - public int getEcCount(final UUID identifier, final long deviceId) { - return ecPreKeys.getCount(identifier, deviceId).join(); + public CompletableFuture getEcCount(final UUID identifier, final long deviceId) { + return ecPreKeys.getCount(identifier, deviceId); } - public int getPqCount(final UUID identifier, final long deviceId) { - return pqPreKeys.getCount(identifier, deviceId).join(); + public CompletableFuture getPqCount(final UUID identifier, final long deviceId) { + return pqPreKeys.getCount(identifier, deviceId); } public void delete(final UUID accountUuid) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index d775533f8..e572aa175 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -645,6 +645,9 @@ class RegistrationControllerTest { Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, SESSION_EXPIRATION_SECONDS)))); + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + final UUID accountIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID(); final Device device = mock(Device.class); @@ -664,6 +667,8 @@ class RegistrationControllerTest { return invocation.getArgument(0); }); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 0ba7183db..94f6db25b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -33,6 +33,7 @@ import java.time.Clock; import java.time.Duration; import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -732,7 +733,7 @@ class AccountsManagerTest { final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L)); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); @@ -783,6 +784,8 @@ class AccountsManagerTest { final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); + final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); // non-PNI stuff should not change @@ -825,7 +828,7 @@ class AccountsManagerTest { UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); + when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index f2b0b6a4b..207864fc0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -66,204 +66,208 @@ class KeysManagerTest { @Test void testStore() { - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Initial pre-key count for an account should be zero"); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Initial pre-key count for an account should be zero"); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent(), "Initial last-resort pre-key for an account should be missing"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Repeatedly storing same key should have no effect"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Uploading new PQ prekeys should have no effect on EC prekeys"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); - assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId()); + assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Uploading new EC prekeys should have no effect on PQ prekeys"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(4), generateTestPreKey(5)), - List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)); - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)).join(); + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(), + assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(), "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); } @Test void testTakeAccountAndDeviceId() { - assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join()); final ECPreKey preKey = generateTestPreKey(1); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); - final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)), null, null, null).join(); + final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join(); assertEquals(Optional.of(preKey), takenKey); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); } @Test void testTakePQ() { - assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join()); final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1); final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2); final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast).join(); - assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); } @Test void testGetCount() { - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null).join(); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); } @Test void testDeleteByAccount() { keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), - generateTestECSignedPreKey(5), - generateTestKEMSignedPreKey(6)); + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), + generateTestECSignedPreKey(5), + generateTestKEMSignedPreKey(6)) + .join(); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(generateTestPreKey(7)), - List.of(generateTestKEMSignedPreKey(8)), - generateTestECSignedPreKey(9), - generateTestKEMSignedPreKey(10)); + List.of(generateTestPreKey(7)), + List.of(generateTestKEMSignedPreKey(8)), + generateTestECSignedPreKey(9), + generateTestKEMSignedPreKey(10)) + .join(); - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); keysManager.delete(ACCOUNT_UUID); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); } @Test void testDeleteByAccountAndDevice() { keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1),generateTestPreKey(2)), - List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), - generateTestECSignedPreKey(5), - generateTestKEMSignedPreKey(6)); + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), + generateTestECSignedPreKey(5), + generateTestKEMSignedPreKey(6)) + .join(); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(generateTestPreKey(7)), - List.of(generateTestKEMSignedPreKey(8)), - generateTestECSignedPreKey(9), - generateTestKEMSignedPreKey(10)); + List.of(generateTestPreKey(7)), + List.of(generateTestKEMSignedPreKey(8)), + generateTestECSignedPreKey(9), + generateTestKEMSignedPreKey(10)) + .join(); - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); keysManager.delete(ACCOUNT_UUID, DEVICE_ID); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); } @Test void testStorePqLastResort() { - assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); + assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); keysManager.storePqLastResort( ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); - assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); - assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId()); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent()); + Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join(); + assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); + assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId()); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent()); keysManager.storePqLastResort( ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); - assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates"); - assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones"); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone"); - assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones"); + Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join(); + assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates"); + assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone"); + assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); } @Test void testGetPqEnabledDevices() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join(); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join(); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join(); assertIterableEquals( Set.of(DEVICE_ID + 1, DEVICE_ID + 2), - Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID))); + Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join())); } @Test @@ -273,14 +277,15 @@ class KeysManagerTest { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1)), - List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)), - KeysHelper.signedECPreKey(3, identityKeyPair), - KeysHelper.signedKEMPreKey(4, identityKeyPair)); + List.of(generateTestPreKey(1)), + List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)), + KeysHelper.signedECPreKey(3, identityKeyPair), + KeysHelper.signedKEMPreKey(4, identityKeyPair)) + .join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 1f34c2b1c..35b35a2e3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import javax.ws.rs.Path; import javax.ws.rs.client.Entity; @@ -282,6 +283,9 @@ class DeviceControllerTest { when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + final LinkDeviceRequest request = new LinkDeviceRequest("5678901", new AccountAttributes(fetchesMessages, 1234, null, null, true, null), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index ad94fddc7..a3133e1cb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -222,15 +222,16 @@ class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); + when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); - when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); - when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI)); + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); + when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI))); + when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI))); - when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); - when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); + when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5)); + when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5)); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); @@ -334,7 +335,7 @@ class KeysControllerTest { @Test void validSingleRequestPqTestNoPqKeysV2() { - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty()); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); PreKeyResponse result = resources.getJerseyTest() .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) @@ -520,10 +521,10 @@ class KeysControllerTest { @Test void validMultiRequestTestV2() { - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); - when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); - when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2))); + when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) @@ -575,13 +576,15 @@ class KeysControllerTest { @Test void validMultiRequestPqTestV2() { - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); - when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); - when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2)); - when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3)); - when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty()); + when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); + when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); + when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID))