diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 7ed59ea0d..472b32db1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -336,12 +336,16 @@ public class AccountsManager { keysManager.storeEcSignedPreKeys(phoneNumberIdentifier, pniSignedPreKeys); if (pniPqLastResortPreKeys != null) { - keysManager.storePqLastResort( - phoneNumberIdentifier, - keysManager.getPqEnabledDevices(uuid).join().stream().collect( - Collectors.toMap( - Function.identity(), - pniPqLastResortPreKeys::get))); + keysManager.getPqEnabledDevices(uuid).thenCompose( + deviceIds -> keysManager.storePqLastResort( + phoneNumberIdentifier, + deviceIds.stream() + .filter(pniPqLastResortPreKeys::containsKey) + .collect( + Collectors.toMap( + Function.identity(), + pniPqLastResortPreKeys::get)))) + .join(); } }); @@ -373,11 +377,17 @@ public class AccountsManager { final UUID pni = account.getPhoneNumberIdentifier(); final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); - final List pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni).join(); keysManager.delete(pni); keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys).join(); - if (pniPqLastResortPreKeys != null && !pqEnabledDeviceIDs.isEmpty()) { - keysManager.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))).join(); + if (pniPqLastResortPreKeys != null) { + keysManager.getPqEnabledDevices(pni) + .thenCompose( + deviceIds -> keysManager.storePqLastResort( + pni, + deviceIds.stream() + .filter(pniPqLastResortPreKeys::containsKey) + .collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)))) + .join(); } return updatedAccount; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 385f6de2e..2a51fa5b6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -1118,9 +1118,13 @@ class AccountsManagerTest { final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L, 3L))); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final List devices = List.of( + DevicesHelper.createDevice(1L, 0L, 101), + DevicesHelper.createDevice(2L, 0L, 102), + DevicesHelper.createDisabledDevice(3L, 103)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account updatedAccount = accountsManager.changeNumber( account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); @@ -1306,11 +1310,7 @@ class AccountsManagerTest { when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of())); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), any())).thenAnswer( - invocation -> { - assertFalse(invocation.getArgument(1, Map.class).isEmpty()); - return CompletableFuture.completedFuture(null); - }); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); @@ -1342,9 +1342,7 @@ class AccountsManagerTest { verify(keysManager).delete(oldPni); verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys); - - // no pq-enabled devices -> no pq last resort keys should be stored - verify(keysManager, never()).storePqLastResort(any(), any()); + verify(keysManager).storePqLastResort(any(), argThat(Map::isEmpty)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java index 54ed18e7d..7d9a453d9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java @@ -34,6 +34,17 @@ public class DevicesHelper { return device; } + public static Device createDisabledDevice(final long deviceId, final int registrationId) { + final Device device = new Device(); + device.setId(deviceId); + device.setUserAgent("OWT"); + device.setRegistrationId(registrationId); + + setEnabled(device, false); + + return device; + } + public static void setEnabled(Device device, boolean enabled) { if (enabled) { device.setSignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));