From 5e221fa9a33820f826e54916401a1689e1270f5c Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:07:33 -0400 Subject: [PATCH] Tests for validation of Kyber keys on PNI change/key distribution events Co-authored-by: Jonathan Klabunde Tomer --- .../storage/AccountsManager.java | 2 +- .../storage/AccountsManagerTest.java | 89 +++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index c88165dc2..dc7e7197e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -441,7 +441,7 @@ public class AccountsManager { pniPqLastResortPreKeys.keySet(), Collections.emptySet()); } - + // Check that all devices are accounted for in the map of new PNI registration IDs DestinationDeviceValidator.validateCompleteDeviceList( account, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 4e10a4977..f82c67600 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -1125,6 +1125,37 @@ class AccountsManagerTest { verifyNoMoreInteractions(keysManager); } + + @Test + void testChangePhoneNumberWithMismatchedPqKeys() throws InterruptedException, MismatchedDevicesException { + final String originalNumber = "+14152222222"; + final String targetNumber = "+14153333333"; + final UUID existingAccountUuid = UUID.randomUUID(); + final UUID uuid = UUID.randomUUID(); + final UUID originalPni = UUID.randomUUID(); + final UUID targetPni = UUID.randomUUID(); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final Map newSignedKeys = Map.of( + 1L, KeysHelper.signedECPreKey(1, identityKeyPair), + 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair)); + final Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]); + when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); + + final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); + assertThrows(MismatchedDevicesException.class, + () -> accountsManager.changeNumber( + account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds)); + + verifyNoInteractions(accounts); + verifyNoInteractions(keysManager); + } + @Test void testChangePhoneNumberViaUpdate() { final String originalNumber = "+14152222222"; @@ -1242,6 +1273,64 @@ class AccountsManagerTest { verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); } + @Test + void testPniUpdate_incompleteKeys() { + final String number = "+14152222222"; + + List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final Map newSignedKeys = Map.of( + 2L, KeysHelper.signedECPreKey(1, identityKeyPair), + 3L, KeysHelper.signedECPreKey(2, identityKeyPair)); + Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + UUID oldUuid = account.getUuid(); + UUID oldPni = account.getPhoneNumberIdentifier(); + + Map oldSignedPreKeys = account.getDevices().stream() + .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); + + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + assertThrows(MismatchedDevicesException.class, + () -> accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds)); + + verifyNoInteractions(accounts); + verifyNoInteractions(deletedAccounts); + verifyNoInteractions(keysManager); + } + + @Test + void testPniPqUpdate_incompleteKeys() { + final String number = "+14152222222"; + + List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final Map newSignedKeys = Map.of( + 1L, KeysHelper.signedECPreKey(1, identityKeyPair), + 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair)); + Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + UUID oldUuid = account.getUuid(); + UUID oldPni = account.getPhoneNumberIdentifier(); + + Map oldSignedPreKeys = account.getDevices().stream() + .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); + + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + assertThrows(MismatchedDevicesException.class, + () -> accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds)); + + verifyNoInteractions(accounts); + verifyNoInteractions(deletedAccounts); + verifyNoInteractions(keysManager); + } + @Test void testReserveUsernameHash() throws UsernameHashNotAvailableException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);