From 8e598c19dcb52aa78730136468869c0caf0ad858 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:11:22 -0700 Subject: [PATCH] don't attempt to update KEM prekeys if we have no PQ-enabled devices --- .../storage/AccountsManager.java | 2 +- .../storage/AccountsManagerTest.java | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 57ee23382..4cc24c421 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -372,7 +372,7 @@ public class AccountsManager { final List pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni).join(); keysManager.delete(pni); keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys).join(); - if (pniPqLastResortPreKeys != null) { + if (pniPqLastResortPreKeys != null && !pqEnabledDeviceIDs.isEmpty()) { keysManager.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))).join(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index b0c0f8d93..4abee6efb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -16,6 +16,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; @@ -1264,6 +1265,67 @@ class AccountsManagerTest { verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); } + @Test + void testPniNonPqToPqUpdate() throws MismatchedDevicesException { + final String number = "+14152222222"; + + List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final Map newSignedKeys = Map.of( + 1L, KeysHelper.signedECPreKey(1, identityKeyPair), + 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), + 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); + Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + UUID oldUuid = account.getUuid(); + UUID oldPni = account.getPhoneNumberIdentifier(); + + when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of())); + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), any())).thenAnswer( + invocation -> { + assertFalse(invocation.getArgument(1, Map.class).isEmpty()); + return CompletableFuture.completedFuture(null); + }); + + Map oldSignedPreKeys = account.getDevices().stream() + .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); + + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + final Account updatedAccount = + accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds); + + // non-PNI-keys stuff should not change + assertEquals(oldUuid, updatedAccount.getUuid()); + assertEquals(number, updatedAccount.getNumber()); + assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); + assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); + assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream() + .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)))); + assertEquals(Map.of(1L, 101, 2L, 102), + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); + + // PNI keys should + assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); + assertEquals(newSignedKeys, + updatedAccount.getDevices().stream() + .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.PNI)))); + assertEquals(newRegistrationIds, + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); + + verify(accounts).update(any()); + + verify(keysManager).delete(oldPni); + verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys); + + // no pq-enabled devices -> no pq last resort keys should be stored + verify(keysManager, never()).storePqLastResort(any(), any()); + } + @Test void testPniUpdate_incompleteKeys() { final String number = "+14152222222";