diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java index a35b6b057..368d98f62 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; +import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -138,6 +139,54 @@ public class AccountControllerV2 { } } + @Timed + @PUT + @Path("/phone_number_identity_key_distribution") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + @Operation(summary = "Updates key material for the phone-number identity for all devices and sends a synchronization message to companion devices") + public AccountIdentityResponse distributePhoneNumberIdentityKeys(@Auth final AuthenticatedAccount authenticatedAccount, + @NotNull @Valid final PhoneNumberIdentityKeyDistributionRequest request) { + + if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) { + throw new ForbiddenException(); + } + + final Account account = authenticatedAccount.getAccount(); + if (!account.isPniSupported()) { + throw new WebApplicationException(Response.status(425).build()); + } + + try { + final Account updatedAccount = changeNumberManager.updatePNIKeys( + authenticatedAccount.getAccount(), + request.pniIdentityKey(), + request.devicePniSignedPrekeys(), + request.deviceMessages(), + request.pniRegistrationIds()); + + return new AccountIdentityResponse( + updatedAccount.getUuid(), + updatedAccount.getNumber(), + updatedAccount.getPhoneNumberIdentifier(), + updatedAccount.getUsernameHash().orElse(null), + updatedAccount.isStorageSupported()); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), + e.getExtraDevices())) + .build()); + } catch (StaleDevicesException e) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevices(e.getStaleDevices())) + .build()); + } catch (IllegalArgumentException e) { + throw new BadRequestException(e); + } + } + @Timed @PUT @Path("/phone_number_discoverability") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentityResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentityResponse.java index a5f44596b..ec9e186aa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentityResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentityResponse.java @@ -8,14 +8,25 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; +import io.swagger.v3.oas.annotations.media.Schema; import java.util.UUID; import javax.annotation.Nullable; -public record AccountIdentityResponse(UUID uuid, - String number, - UUID pni, - @JsonSerialize(using = ByteArrayBase64UrlAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayBase64UrlAdapter.Deserializing.class) - @Nullable byte[] usernameHash, - boolean storageCapable) { +public record AccountIdentityResponse( + @Schema(description="the account identifier for this account") + UUID uuid, + + @Schema(description="the phone number associated with this account") + String number, + + @Schema(description="the account identifier for this account's phone-number identity") + UUID pni, + + @Schema(description="a hash of this account's username, if set") + @JsonSerialize(using = ByteArrayBase64UrlAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayBase64UrlAdapter.Deserializing.class) + @Nullable byte[] usernameHash, + + @Schema(description="whether any of this account's devices support storage") + boolean storageCapable) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java new file mode 100644 index 000000000..f32fc04ea --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import javax.validation.Valid; +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; + +public record PhoneNumberIdentityKeyDistributionRequest( + @NotBlank + @Schema(description="the new identity key for this account's phone-number identity") + String pniIdentityKey, + + @NotNull + @Valid + @Schema(description="A message for each companion device to pass its new private keys") + List<@NotNull @Valid IncomingMessage> deviceMessages, + + @NotNull + @Valid + @Schema(description="The public key of a new signed elliptic-curve prekey pair for each device") + Map devicePniSignedPrekeys, + + @NotNull + @Valid + @Schema(description="The new registration ID to use for the phone-number identity of each device") + Map pniRegistrationIds) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 4c216070d..b18fc3bb5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -28,13 +28,17 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.Nullable; + +import org.apache.commons.lang3.ObjectUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; @@ -255,24 +259,13 @@ public class AccountsManager { final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier(); if (originalNumber.equals(number)) { + if (pniIdentityKey != null) { + throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePNIKeys"); + } return account; } - if (pniSignedPreKeys != null && pniRegistrationIds != null) { - // Check that all including master ID are in signed pre-keys - DestinationDeviceValidator.validateCompleteDeviceList( - account, - pniSignedPreKeys.keySet(), - Collections.emptySet()); - - // Check that all devices are accounted for in the map of new PNI registration IDs - DestinationDeviceValidator.validateCompleteDeviceList( - account, - pniRegistrationIds.keySet(), - Collections.emptySet()); - } else if (pniSignedPreKeys != null || pniRegistrationIds != null) { - throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null"); - } + validateDevices(account, pniSignedPreKeys, pniRegistrationIds); final AtomicReference updatedAccount = new AtomicReference<>(); @@ -297,22 +290,7 @@ public class AccountsManager { numberChangedAccount = updateWithRetries( account, - a -> { - //noinspection ConstantConditions - if (pniSignedPreKeys != null && pniRegistrationIds != null) { - pniSignedPreKeys.forEach((deviceId, signedPreKey) -> - a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey))); - - pniRegistrationIds.forEach((deviceId, registrationId) -> - a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId))); - } - - if (pniIdentityKey != null) { - a.setPhoneNumberIdentityKey(pniIdentityKey); - } - - return true; - }, + a -> setPNIKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), a -> accounts.changeNumber(a, number, phoneNumberIdentifier), () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); @@ -329,6 +307,58 @@ public class AccountsManager { return updatedAccount.get(); } + public Account updatePNIKeys(final Account account, + final String pniIdentityKey, + final Map pniSignedPreKeys, + final Map pniRegistrationIds) throws MismatchedDevicesException { + validateDevices(account, pniSignedPreKeys, pniRegistrationIds); + + return update(account, a -> { return setPNIKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); + } + + private boolean setPNIKeys(final Account account, + @Nullable final String pniIdentityKey, + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniRegistrationIds) { + if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { + return true; + } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { + throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); + } + + pniSignedPreKeys.forEach((deviceId, signedPreKey) -> + account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey))); + + pniRegistrationIds.forEach((deviceId, registrationId) -> + account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId))); + + account.setPhoneNumberIdentityKey(pniIdentityKey); + + return true; + } + + private void validateDevices(final Account account, + final Map pniSignedPreKeys, + final Map pniRegistrationIds) throws MismatchedDevicesException { + if (pniSignedPreKeys == null && pniRegistrationIds == null) { + return; + } else if (pniSignedPreKeys == null || pniRegistrationIds == null) { + throw new IllegalArgumentException("Signed pre-keys and registration IDs must both be null or both be non-null"); + } + + // Check that all including master ID are in signed pre-keys + DestinationDeviceValidator.validateCompleteDeviceList( + account, + pniSignedPreKeys.keySet(), + Collections.emptySet()); + + // Check that all devices are accounted for in the map of new PNI registration IDs + DestinationDeviceValidator.validateCompleteDeviceList( + account, + pniRegistrationIds.keySet(), + Collections.emptySet()); + } + public record UsernameReservation(Account account, byte[] reservedUsernameHash){} /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 410071fde..c83ff1831 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import com.google.protobuf.ByteString; import org.apache.commons.lang3.ObjectUtils; import org.slf4j.Logger; @@ -46,48 +47,70 @@ public class ChangeNumberManager { throws InterruptedException, MismatchedDevicesException, StaleDevicesException { if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { - assert pniIdentityKey != null; - assert deviceSignedPreKeys != null; - assert deviceMessages != null; - assert pniRegistrationIds != null; - - // Check that all except master ID are in device messages - DestinationDeviceValidator.validateCompleteDeviceList( - account, - deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()), - Set.of(Device.MASTER_ID)); - - DestinationDeviceValidator.validateRegistrationIds( - account, - deviceMessages, - IncomingMessage::destinationDeviceId, - IncomingMessage::destinationRegistrationId, - false); + // AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds + validateDeviceMessages(account, deviceMessages); } else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null"); } - final Account updatedAccount; if (number.equals(account.getNumber())) { - // This may be a request that got repeated due to poor network conditions or other client error; take no action, - // but report success since the account is in the desired state - updatedAccount = account; - } else { - updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); + // The client has gotten confused/desynchronized with us about their own phone number, most likely due to losing + // our OK response to an immediately preceding change-number request, and are sending a change they don't realize + // is a no-op change. + // + // We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key + // material and distribute the sync messages, to be sure all clients agree with us and each other about what their + // keys are. Pretend this change-number request was actually a PNI key distribution request. + return updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds); } - // Whether the account already has this number or not, we resend messages. This makes it so the client can resend a - // request they didn't get a response for (timeout, etc) to make sure their messages sent even if the first time - // around the server crashed at/above this point. + final Account updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); + if (deviceMessages != null) { - deviceMessages.forEach(message -> - sendMessageToSelf(updatedAccount, updatedAccount.getDevice(message.destinationDeviceId()), message)); + sendDeviceMessages(updatedAccount, deviceMessages); } return updatedAccount; } + public Account updatePNIKeys(final Account account, + final String pniIdentityKey, + final Map deviceSignedPreKeys, + final List deviceMessages, + final Map pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { + validateDeviceMessages(account, deviceMessages); + + // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb + // write anyway. Linked devices can handle some wasted extra key rotations. + final Account updatedAccount = accountsManager.updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); + + sendDeviceMessages(updatedAccount, deviceMessages); + return updatedAccount; + } + + private void validateDeviceMessages(final Account account, + final List deviceMessages) throws MismatchedDevicesException, StaleDevicesException { + // Check that all except master ID are in device messages + DestinationDeviceValidator.validateCompleteDeviceList( + account, + deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()), + Set.of(Device.MASTER_ID)); + + // check that all sync messages are to the current registration ID for the matching device + DestinationDeviceValidator.validateRegistrationIds( + account, + deviceMessages, + IncomingMessage::destinationDeviceId, + IncomingMessage::destinationRegistrationId, + false); + } + + private void sendDeviceMessages(final Account account, final List deviceMessages) { + deviceMessages.forEach(message -> + sendMessageToSelf(account, account.getDevice(message.destinationDeviceId()), message)); + } + @VisibleForTesting void sendMessageToSelf( Account sourceAndDestinationAccount, Optional destinationDevice, IncomingMessage message) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index d2d517f14..0a228934e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -323,13 +323,38 @@ class AccountControllerTest { final String pniIdentityKey = invocation.getArgument(2, String.class); final UUID uuid = account.getUuid(); + final UUID pni = number.equals(account.getNumber()) ? account.getPhoneNumberIdentifier() : UUID.randomUUID(); final List devices = account.getDevices(); final Account updatedAccount = mock(Account.class); when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getNumber()).thenReturn(number); when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); - when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); + when(updatedAccount.getDevices()).thenReturn(devices); + + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } + + return updatedAccount; + }); + + when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { + final Account account = invocation.getArgument(0, Account.class); + final String pniIdentityKey = invocation.getArgument(1, String.class); + + final String number = account.getNumber(); + final UUID uuid = account.getUuid(); + final UUID pni = account.getPhoneNumberIdentifier(); + final List devices = account.getDevices(); + + final Account updatedAccount = mock(Account.class); + when(updatedAccount.getNumber()).thenReturn(number); + when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); when(updatedAccount.getDevices()).thenReturn(devices); for (long i = 1; i <= 3; i++) { @@ -1646,6 +1671,61 @@ class AccountControllerTest { assertThat(accountIdentityResponse.pni()).isNotEqualTo(AuthHelper.VALID_PNI); } + @Test + void testChangePhoneNumberSameNumberChangePrekeys() throws Exception { + final String code = "987654"; + final String pniIdentityKey = "changed-pni-identity-key"; + final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); + + Device device2 = mock(Device.class); + when(device2.getId()).thenReturn(2L); + when(device2.isEnabled()).thenReturn(true); + when(device2.getRegistrationId()).thenReturn(2); + + Device device3 = mock(Device.class); + when(device3.getId()).thenReturn(3L); + when(device3.isEnabled()).thenReturn(true); + when(device3.getRegistrationId()).thenReturn(3); + + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3)); + when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); + when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); + + when(pendingAccountsManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn( + Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), "push", sessionId))); + + when(registrationServiceClient.checkVerificationCode(any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(true)); + + var deviceMessages = List.of( + new IncomingMessage(1, 2, 2, "content2"), + new IncomingMessage(1, 3, 3, "content3")); + var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + + final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + + final AccountIdentityResponse accountIdentityResponse = + resources.getJerseyTest() + .target("/v1/accounts/number") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ChangePhoneNumberRequest( + AuthHelper.VALID_NUMBER, code, null, + pniIdentityKey, deviceMessages, + deviceKeys, + registrationIds), + MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); + + verify(changeNumberManager).changeNumber( + eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any()); + verifyNoInteractions(rateLimiter); + verifyNoInteractions(pendingAccountsManager); + + assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID); + assertThat(accountIdentityResponse.number()).isEqualTo(AuthHelper.VALID_NUMBER); + assertThat(accountIdentityResponse.pni()).isEqualTo(AuthHelper.VALID_PNI); + } + @Test void testSetRegistrationLock() { Response response = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index 6009a2034..246fb817b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -71,6 +71,7 @@ import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; +import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -147,7 +148,11 @@ class AccountControllerV2Test { when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getNumber()).thenReturn(number); when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); - when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + if (number.equals(account.getNumber())) { + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); + } else { + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + } when(updatedAccount.getDevices()).thenReturn(devices); for (long i = 1; i <= 3; i++) { @@ -187,6 +192,29 @@ class AccountControllerV2Test { assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni()); } + @Test + void changeNumberSameNumber() throws Exception { + final AccountIdentityResponse accountIdentityResponse = + resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity( + new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, + "pni-identity-key", + Collections.emptyList(), + Collections.emptyMap(), Collections.emptyMap()), + MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); + + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), + any()); + + assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); + assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); + assertEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni()); + } + @Test void unprocessableRequestJson() { final Invocation.Builder request = resources.getJerseyTest() @@ -426,6 +454,144 @@ class AccountControllerV2Test { } } + @Nested + class PhoneNumberIdentityKeyDistribution { + + @BeforeEach + void setUp() throws Exception { + when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer( + (Answer) invocation -> { + final Account account = invocation.getArgument(0, Account.class); + final String pniIdentityKey = invocation.getArgument(1, String.class); + + final UUID uuid = account.getUuid(); + final UUID pni = account.getPhoneNumberIdentifier(); + final String number = account.getNumber(); + final List devices = account.getDevices(); + + final Account updatedAccount = mock(Account.class); + when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getNumber()).thenReturn(number); + when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); + when(updatedAccount.getDevices()).thenReturn(devices); + + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } + + return updatedAccount; + }); + } + + @Test + void pniKeyDistributionSuccess() throws Exception { + when(AuthHelper.VALID_ACCOUNT.isPniSupported()).thenReturn(true); + + final AccountIdentityResponse accountIdentityResponse = + resources.getJerseyTest() + .target("/v2/accounts/phone_number_identity_key_distribution") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.json(requestJson()), AccountIdentityResponse.class); + + verify(changeNumberManager).updatePNIKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any()); + + assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); + assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); + assertEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni()); + } + + @Test + void unprocessableRequestJson() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/phone_number_identity_key_distribution") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(unprocessableJson()))) { + assertEquals(400, response.getStatus()); + } + } + + @Test + void missingBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/phone_number_identity_key_distribution") + .request(); + try (Response response = request.put(Entity.json(requestJson()))) { + assertEquals(401, response.getStatus()); + } + } + + @Test + void invalidBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/phone_number_identity_key_distribution") + .request() + .header(HttpHeaders.AUTHORIZATION, "Basic but-invalid"); + try (Response response = request.put(Entity.json(requestJson()))) { + assertEquals(401, response.getStatus()); + } + } + + @Test + void invalidRequestBody() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/phone_number_identity_key_distribution") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(invalidRequestJson()))) { + assertEquals(422, response.getStatus()); + } + } + + /** + * Valid request JSON for a {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest} + */ + private static String requestJson() { + return """ + { + "pniIdentityKey": "pni-identity-key", + "deviceMessages": [], + "devicePniSignedPrekeys": {}, + "pniRegistrationIds": {} + } + """; + } + + /** + * Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest}, but that + * fails validation + */ + private static String invalidRequestJson() { + return """ + { + "pniIdentityKey": null, + "deviceMessages": [], + "devicePniSignedPrekeys": {}, + "pniRegistrationIds": {} + } + """; + } + + /** + * Request JSON that cannot be marshalled into + * {@link org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest} + */ + private static String unprocessableJson() { + return """ + { + "pniIdentityKey": [] + } + """; + } + + } + @Nested class PhoneNumberDiscoverability { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index db28c03db..4a8af3d29 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -5,6 +5,8 @@ package org.whispersystems.textsecuregcm.storage; +import org.whispersystems.textsecuregcm.util.SystemMapper; + import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; @@ -39,6 +41,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -59,6 +62,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2 import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; class AccountsManagerTest { @@ -694,6 +698,22 @@ class AccountsManagerTest { verify(keys, never()).delete(any()); } + @Test + void testChangePhoneNumberSameNumberWithPNIData() throws InterruptedException, MismatchedDevicesException { + final String number = "+14152222222"; + + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); + assertThrows(IllegalArgumentException.class, + () -> accountsManager.changeNumber( + account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), Map.of(1L, 101)), + "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); + + verify(accounts, never()).update(any()); + verifyNoInteractions(deletedAccountsManager); + verifyNoInteractions(directoryQueue); + verifyNoInteractions(keys); + } + @Test void testChangePhoneNumberExistingAccount() throws InterruptedException, MismatchedDevicesException { doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty())) @@ -733,6 +753,45 @@ class AccountsManagerTest { assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID()))); } + @Test + void testPNIUpdate() throws InterruptedException, MismatchedDevicesException { + final String number = "+14152222222"; + + List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); + Map newSignedKeys = Map.of( + 1L, new SignedPreKey(1L, "pub1", "sig1"), + 2L, new SignedPreKey(2L, "pub2", "sig2")); + Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + UUID oldUuid = account.getUuid(); + UUID oldPni = account.getPhoneNumberIdentifier(); + Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); + + final Account updatedAccount = accountsManager.updatePNIKeys(account, "new-pni-identity-key", newSignedKeys, newRegistrationIds); + + // non-PNI stuff should not change + assertEquals(oldUuid, updatedAccount.getUuid()); + assertEquals(number, updatedAccount.getNumber()); + assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); + assertEquals(null, updatedAccount.getIdentityKey()); + assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey))); + assertEquals(Map.of(1L, 101, 2L, 102), + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); + + // PNI stuff should + assertEquals("new-pni-identity-key", updatedAccount.getPhoneNumberIdentityKey()); + assertEquals(newSignedKeys, + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); + assertEquals(newRegistrationIds, + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); + + verify(accounts).update(any()); + verifyNoInteractions(deletedAccountsManager); + verifyNoInteractions(directoryQueue); + verifyNoInteractions(keys); + } + @Test void testReserveUsernameHash() throws UsernameHashNotAvailableException { final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 19ff99867..a995939a2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -68,6 +69,25 @@ public class ChangeNumberManagerTest { return updatedAccount; }); + + when(accountsManager.updatePNIKeys(any(), any(), any(), any())).thenAnswer((Answer)invocation -> { + final Account account = invocation.getArgument(0, Account.class); + + final UUID uuid = account.getUuid(); + final UUID pni = account.getPhoneNumberIdentifier(); + final List devices = account.getDevices(); + + final Account updatedAccount = mock(Account.class); + when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); + when(updatedAccount.getDevices()).thenReturn(devices); + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } + + return updatedAccount; + }); } @Test @@ -134,6 +154,86 @@ public class ChangeNumberManagerTest { assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); } + @Test + void changeNumberSameNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { + final String originalE164 = "+18005551234"; + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn(originalE164); + when(account.getUuid()).thenReturn(aci); + when(account.getPhoneNumberIdentifier()).thenReturn(pni); + + final Device d2 = mock(Device.class); + when(d2.isEnabled()).thenReturn(true); + when(d2.getId()).thenReturn(2L); + + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevices()).thenReturn(List.of(d2)); + + final String pniIdentityKey = "pni-identity-key"; + final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 19); + + final IncomingMessage msg = mock(IncomingMessage.class); + when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + + changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); + + verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); + + final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); + verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + + assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); + assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(Device.MASTER_ID, envelope.getSourceDevice()); + assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); + } + + @Test + void updatePNIKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(aci); + when(account.getPhoneNumberIdentifier()).thenReturn(pni); + + final Device d2 = mock(Device.class); + when(d2.isEnabled()).thenReturn(true); + when(d2.getId()).thenReturn(2L); + + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevices()).thenReturn(List.of(d2)); + + final String pniIdentityKey = "pni-identity-key"; + final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 19); + + final IncomingMessage msg = mock(IncomingMessage.class); + when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + + changeNumberManager.updatePNIKeys(account, pniIdentityKey, prekeys, List.of(msg), registrationIds); + + verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); + + final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); + verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + + assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); + assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(Device.MASTER_ID, envelope.getSourceDevice()); + assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); + } + @Test void changeNumberMismatchedRegistrationId() { final Account account = mock(Account.class); @@ -164,6 +264,36 @@ public class ChangeNumberManagerTest { () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds)); } + @Test + void updatePNIKeysMismatchedRegistrationId() { + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005551234"); + + final List devices = new ArrayList<>(); + + for (int i = 1; i <= 3; i++) { + final Device device = mock(Device.class); + when(device.getId()).thenReturn((long) i); + when(device.isEnabled()).thenReturn(true); + when(device.getRegistrationId()).thenReturn(i); + + devices.add(device); + when(account.getDevice(i)).thenReturn(Optional.of(device)); + } + + when(account.getDevices()).thenReturn(devices); + + final List messages = List.of( + new IncomingMessage(1, 2, 1, "foo"), + new IncomingMessage(1, 3, 1, "foo")); + + final Map preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + + assertThrows(StaleDevicesException.class, + () -> changeNumberManager.updatePNIKeys(account, "pni-identity-key", preKeys, messages, registrationIds)); + } + @Test void changeNumberMissingData() { final Account account = mock(Account.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java index 0c40a87a5..ea424e236 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java @@ -19,10 +19,15 @@ public class DevicesHelper { } public static Device createDevice(final long deviceId, final long lastSeen) { + return createDevice(deviceId, lastSeen, 0); + } + + public static Device createDevice(final long deviceId, final long lastSeen, final int registrationId) { final Device device = new Device(); device.setId(deviceId); device.setLastSeen(lastSeen); device.setUserAgent("OWT"); + device.setRegistrationId(registrationId); setEnabled(device, true);