diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 977546b67..edd144380 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -8,6 +8,9 @@ import com.codahale.metrics.annotation.Timed; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import java.security.SecureRandom; import java.util.LinkedList; @@ -18,6 +21,7 @@ import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; +import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; import javax.ws.rs.PUT; @@ -36,9 +40,12 @@ import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfoList; import org.whispersystems.textsecuregcm.entities.DeviceResponse; +import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; +import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -47,6 +54,7 @@ import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.VerificationCode; @@ -54,7 +62,7 @@ import org.whispersystems.textsecuregcm.util.VerificationCode; @Tag(name = "Devices") public class DeviceController { - private static final int MAX_DEVICES = 6; + static final int MAX_DEVICES = 6; private final StoredVerificationCodeManager pendingDevices; private final AccountsManager accounts; @@ -142,75 +150,69 @@ public class DeviceController { return verificationCode; } + /** + * @deprecated callers should use {@link #linkDevice(BasicAuthorizationHeader, LinkDeviceRequest, ContainerRequest)} + * instead + */ @Timed @PUT @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Path("/{verification_code}") @ChangesDeviceEnabledState + @Deprecated(forRemoval = true) public DeviceResponse verifyDeviceToken(@PathParam("verification_code") String verificationCode, @HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @NotNull @Valid AccountAttributes accountAttributes, @Context ContainerRequest containerRequest) throws RateLimitExceededException, DeviceLimitExceededException { - String number = authorizationHeader.getUsername(); - String password = authorizationHeader.getPassword(); + final Pair accountAndDevice = createDevice(authorizationHeader.getUsername(), + authorizationHeader.getPassword(), + verificationCode, + accountAttributes, + containerRequest, + Optional.empty()); - rateLimiters.getVerifyDeviceLimiter().validate(number); + final Account account = accountAndDevice.first(); + final Device device = accountAndDevice.second(); - Optional storedVerificationCode = pendingDevices.getCodeForNumber(number); + return new DeviceResponse(account.getUuid(), account.getPhoneNumberIdentifier(), device.getId()); + } - if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(verificationCode)) { - throw new WebApplicationException(Response.status(403).build()); - } + @Timed + @PUT + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + @Path("/link") + @ChangesDeviceEnabledState + @Operation(summary = "Link a device to an account", + description = """ + Links a device to an account identified by a given phone number. + """) + @ApiResponse(responseCode = "200", description = "The new device was linked to the calling account", useReturnTypeSchema = true) + @ApiResponse(responseCode = "403", description = "The given account was not found or the given verification code was incorrect") + @ApiResponse(responseCode = "411", description = "The given account already has its maximum number of linked devices") + @ApiResponse(responseCode = "422", description = "The request did not pass validation") + @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( + name = "Retry-After", + description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) + public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader, + @NotNull @Valid LinkDeviceRequest linkDeviceRequest, + @Context ContainerRequest containerRequest) + throws RateLimitExceededException, DeviceLimitExceededException { - Optional account = accounts.getByE164(number); + final Pair accountAndDevice = createDevice(authorizationHeader.getUsername(), + authorizationHeader.getPassword(), + linkDeviceRequest.verificationCode(), + linkDeviceRequest.accountAttributes(), + containerRequest, + Optional.of(linkDeviceRequest.deviceActivationRequest())); - if (account.isEmpty()) { - throw new WebApplicationException(Response.status(403).build()); - } + final Account account = accountAndDevice.first(); + final Device device = accountAndDevice.second(); - // Normally, the "do we need to refresh somebody's websockets" listener can do this on its own. In this case, - // we're not using the conventional authentication system, and so we need to give it a hint so it knows who the - // active user is and what their device states look like. - AuthEnablementRefreshRequirementProvider.setAccount(containerRequest, account.get()); - - int maxDeviceLimit = MAX_DEVICES; - - if (maxDeviceConfiguration.containsKey(account.get().getNumber())) { - maxDeviceLimit = maxDeviceConfiguration.get(account.get().getNumber()); - } - - if (account.get().getEnabledDeviceCount() >= maxDeviceLimit) { - throw new DeviceLimitExceededException(account.get().getDevices().size(), MAX_DEVICES); - } - - final DeviceCapabilities capabilities = accountAttributes.getCapabilities(); - if (capabilities != null && isCapabilityDowngrade(account.get(), capabilities)) { - throw new WebApplicationException(Response.status(409).build()); - } - - Device device = new Device(); - device.setName(accountAttributes.getName()); - device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); - device.setFetchesMessages(accountAttributes.getFetchesMessages()); - device.setRegistrationId(accountAttributes.getRegistrationId()); - accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId); - device.setLastSeen(Util.todayInMillis()); - device.setCreated(System.currentTimeMillis()); - device.setCapabilities(accountAttributes.getCapabilities()); - - final Account updatedAccount = accounts.update(account.get(), a -> { - device.setId(a.getNextDeviceId()); - messages.clear(a.getUuid(), device.getId()); - a.addDevice(device); - }); - - pendingDevices.remove(number); - - return new DeviceResponse(updatedAccount.getUuid(), updatedAccount.getPhoneNumberIdentifier(), device.getId()); + return new DeviceResponse(account.getUuid(), account.getPhoneNumberIdentifier(), device.getId()); } @Timed @@ -236,7 +238,7 @@ public class DeviceController { return new VerificationCode(randomInt); } - private boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities) { + static boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities) { boolean isDowngrade = false; isDowngrade |= account.isStoriesSupported() && !capabilities.isStories(); @@ -248,4 +250,103 @@ public class DeviceController { return isDowngrade; } + + private Pair createDevice(final String phoneNumber, + final String password, + final String verificationCode, + final AccountAttributes accountAttributes, + final ContainerRequest containerRequest, + final Optional maybeDeviceActivationRequest) + throws RateLimitExceededException, DeviceLimitExceededException { + + rateLimiters.getVerifyDeviceLimiter().validate(phoneNumber); + + Optional storedVerificationCode = pendingDevices.getCodeForNumber(phoneNumber); + + if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(verificationCode)) { + throw new WebApplicationException(Response.status(403).build()); + } + + final Account account = accounts.getByE164(phoneNumber) + .orElseThrow(ForbiddenException::new); + + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { + assert deviceActivationRequest.aciSignedPreKey().isPresent(); + assert deviceActivationRequest.pniSignedPreKey().isPresent(); + assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); + assert deviceActivationRequest.pniPqLastResortPreKey().isPresent(); + + final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(), + List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) + && PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), + List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get())); + + if (!allKeysValid) { + throw new WebApplicationException(Response.status(422).build()); + } + }); + + // Normally, the "do we need to refresh somebody's websockets" listener can do this on its own. In this case, + // we're not using the conventional authentication system, and so we need to give it a hint so it knows who the + // active user is and what their device states look like. + AuthEnablementRefreshRequirementProvider.setAccount(containerRequest, account); + + int maxDeviceLimit = MAX_DEVICES; + + if (maxDeviceConfiguration.containsKey(account.getNumber())) { + maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber()); + } + + if (account.getEnabledDeviceCount() >= maxDeviceLimit) { + throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES); + } + + final DeviceCapabilities capabilities = accountAttributes.getCapabilities(); + if (capabilities != null && isCapabilityDowngrade(account, capabilities)) { + throw new WebApplicationException(Response.status(409).build()); + } + + final Device device = new Device(); + device.setName(accountAttributes.getName()); + device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); + device.setFetchesMessages(accountAttributes.getFetchesMessages()); + device.setRegistrationId(accountAttributes.getRegistrationId()); + accountAttributes.getPhoneNumberIdentityRegistrationId().ifPresent(device::setPhoneNumberIdentityRegistrationId); + device.setLastSeen(Util.todayInMillis()); + device.setCreated(System.currentTimeMillis()); + device.setCapabilities(accountAttributes.getCapabilities()); + + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { + device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); + device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get()); + + deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { + device.setApnId(apnRegistrationId.apnRegistrationId()); + device.setVoipApnId(apnRegistrationId.voipRegistrationId()); + }); + + deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId -> + device.setGcmId(gcmRegistrationId.gcmRegistrationId())); + }); + + final Account updatedAccount = accounts.update(account, a -> { + device.setId(a.getNextDeviceId()); + + messages.clear(a.getUuid(), device.getId()); + + keys.delete(a.getUuid(), device.getId()); + keys.delete(a.getPhoneNumberIdentifier(), device.getId()); + + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { + keys.storePqLastResort(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())); + keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())); + }); + + a.addDevice(device); + }); + + pendingDevices.remove(phoneNumber); + + return new Pair<>(updatedAccount, device); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java new file mode 100644 index 000000000..6dab8a238 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java @@ -0,0 +1,55 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import io.swagger.v3.oas.annotations.media.Schema; + +import javax.validation.Valid; +import javax.validation.constraints.AssertTrue; +import java.util.Optional; + +public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ + The verification code associated with this device. Must match the verification code + provided by the server when provisioning this device. + """) + String verificationCode, + + AccountAttributes accountAttributes, + + @JsonUnwrapped + @JsonProperty(access = JsonProperty.Access.READ_ONLY) + DeviceActivationRequest deviceActivationRequest) { + + @JsonCreator + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, + @JsonProperty("accountAttributes") AccountAttributes accountAttributes, + @JsonProperty("aciSignedPreKey") Optional<@Valid SignedPreKey> aciSignedPreKey, + @JsonProperty("pniSignedPreKey") Optional<@Valid SignedPreKey> pniSignedPreKey, + @JsonProperty("aciPqLastResortPreKey") Optional<@Valid SignedPreKey> aciPqLastResortPreKey, + @JsonProperty("pniPqLastResortPreKey") Optional<@Valid SignedPreKey> pniPqLastResortPreKey, + @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, + @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { + + this(verificationCode, accountAttributes, + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken)); + } + + @AssertTrue + public boolean hasAllRequiredFields() { + return deviceActivationRequest().aciSignedPreKey().isPresent() + && deviceActivationRequest().pniSignedPreKey().isPresent() + && deviceActivationRequest().aciPqLastResortPreKey().isPresent() + && deviceActivationRequest().pniPqLastResortPreKey().isPresent(); + } + + @AssertTrue + public boolean hasExactlyOneMessageDeliveryChannel() { + if (accountAttributes.getFetchesMessages()) { + return deviceActivationRequest().apnToken().isEmpty() && deviceActivationRequest().gcmToken().isEmpty(); + } else { + return deviceActivationRequest().apnToken().isPresent() ^ deviceActivationRequest().gcmToken().isPresent(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 4c47583c5..08739bc94 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.eq; @@ -21,10 +22,13 @@ import com.google.common.net.HttpHeaders; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; import javax.ws.rs.Path; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; @@ -35,14 +39,24 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; +import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceResponse; +import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; @@ -56,6 +70,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.VerificationCode; @ExtendWith(DropwizardExtensionsSupport.class) @@ -121,6 +136,7 @@ class DeviceControllerTest { when(account.getNextDeviceId()).thenReturn(42L); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); + when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); when(account.isEnabled()).thenReturn(false); when(account.isSenderKeySupported()).thenReturn(true); when(account.isAnnouncementGroupSupported()).thenReturn(true); @@ -225,6 +241,282 @@ class DeviceControllerTest { assertEquals(401, response.getStatus()); } + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void linkDeviceAtomic(final boolean fetchesMessages, + final Optional apnRegistrationId, + final Optional gcmRegistrationId, + final Optional expectedApnsToken, + final Optional expectedApnsVoipToken, + final Optional expectedGcmToken) { + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + VerificationCode deviceCode = resources.getJerseyTest() + .target("/v1/devices/provisioning/code") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(VerificationCode.class); + + assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); + + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; + + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); + pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); + pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + + when(account.getIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(aciIdentityKeyPair)); + when(account.getPhoneNumberIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(pniIdentityKeyPair)); + + final LinkDeviceRequest request = new LinkDeviceRequest("5678901", + new AccountAttributes(fetchesMessages, 1234, null, null, true, null), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); + + final DeviceResponse response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class); + + assertThat(response.getDeviceId()).isEqualTo(42L); + + final ArgumentCaptor deviceCaptor = ArgumentCaptor.forClass(Device.class); + verify(account).addDevice(deviceCaptor.capture()); + + final Device device = deviceCaptor.getValue(); + + assertEquals(aciSignedPreKey.get(), device.getSignedPreKey()); + assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey()); + assertEquals(fetchesMessages, device.getFetchesMessages()); + + expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), + () -> assertNull(device.getApnId())); + + expectedApnsVoipToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getVoipApnId()), + () -> assertNull(device.getVoipApnId())); + + expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()), + () -> assertNull(device.getGcmId())); + + verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); + verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); + verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID); + verify(keys).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); + verify(keys).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); + } + + private static Stream linkDeviceAtomic() { + final String apnsToken = "apns-token"; + final String apnsVoipToken = "apns-voip-token"; + final String gcmToken = "gcm-token"; + + return Stream.of( + Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()), + Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken)) + ); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages, + final Optional apnRegistrationId, + final Optional gcmRegistrationId) { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + VerificationCode deviceCode = resources.getJerseyTest() + .target("/v1/devices/provisioning/code") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(VerificationCode.class); + + assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); + + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; + + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); + pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); + pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + + when(account.getIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(aciIdentityKeyPair)); + when(account.getPhoneNumberIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(pniIdentityKeyPair)); + + final LinkDeviceRequest request = new LinkDeviceRequest("5678901", + new AccountAttributes(fetchesMessages, 1234, null, null, true, null), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } + + private static Stream linkDeviceAtomicConflictingChannel() { + return Stream.of( + Arguments.of(true, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.of(new GcmRegistrationId("gcm-token"))), + Arguments.of(true, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-token"))), + Arguments.of(true, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.empty()), + Arguments.of(false, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.of(new GcmRegistrationId("gcm-token"))) + ); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + void linkDeviceAtomicMissingProperty(final String aciIdentityKey, + final String pniIdentityKey, + final Optional aciSignedPreKey, + final Optional pniSignedPreKey, + final Optional aciPqLastResortPreKey, + final Optional pniPqLastResortPreKey) { + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + VerificationCode deviceCode = resources.getJerseyTest() + .target("/v1/devices/provisioning/code") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(VerificationCode.class); + + assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); + + when(account.getIdentityKey()).thenReturn(aciIdentityKey); + when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); + + final LinkDeviceRequest request = new LinkDeviceRequest("5678901", + new AccountAttributes(true, 1234, null, null, true, null), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } + + private static Stream linkDeviceAtomicMissingProperty() { + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + final Optional aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); + final Optional pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Optional aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); + final Optional pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + + final String aciIdentityKey = KeysHelper.serializeIdentityKey(aciIdentityKeyPair); + final String pniIdentityKey = KeysHelper.serializeIdentityKey(pniIdentityKeyPair); + + return Stream.of( + Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, Optional.empty(), aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, Optional.empty(), pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, Optional.empty()) + ); + } + + @ParameterizedTest + @MethodSource + void linkDeviceAtomicInvalidSignature(final String aciIdentityKey, + final String pniIdentityKey, + final SignedPreKey aciSignedPreKey, + final SignedPreKey pniSignedPreKey, + final SignedPreKey aciPqLastResortPreKey, + final SignedPreKey pniPqLastResortPreKey) { + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + VerificationCode deviceCode = resources.getJerseyTest() + .target("/v1/devices/provisioning/code") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(VerificationCode.class); + + assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); + + when(account.getIdentityKey()).thenReturn(aciIdentityKey); + when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); + + final LinkDeviceRequest request = new LinkDeviceRequest("5678901", + new AccountAttributes(true, 1234, null, null, true, null), + new DeviceActivationRequest(Optional.of(aciSignedPreKey), Optional.of(pniSignedPreKey), Optional.of(aciPqLastResortPreKey), Optional.of(pniPqLastResortPreKey), Optional.empty(), Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } + + private static Stream linkDeviceAtomicInvalidSignature() { + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + final SignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); + final SignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); + final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); + final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); + + final String aciIdentityKey = KeysHelper.serializeIdentityKey(aciIdentityKeyPair); + final String pniIdentityKey = KeysHelper.serializeIdentityKey(pniIdentityKeyPair); + + return Stream.of( + Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, signedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, signedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, signedPreKeyWithBadSignature(pniPqLastResortPreKey)) + ); + } + + private static SignedPreKey signedPreKeyWithBadSignature(final SignedPreKey signedPreKey) { + return new SignedPreKey(signedPreKey.getKeyId(), + signedPreKey.getPublicKey(), + Base64.getEncoder().encodeToString("incorrect-signature".getBytes(StandardCharsets.UTF_8))); + } + @Test void disabledDeviceRegisterTest() { Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 972e1f9f4..15f1324a4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -130,6 +130,7 @@ public class AccountsHelper { case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing); case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing); case "getIdentityKey" -> when(updatedAccount.getIdentityKey()).thenAnswer(stubbing); + case "getPhoneNumberIdentityKey" -> when(updatedAccount.getPhoneNumberIdentityKey()).thenAnswer(stubbing); case "getBadges" -> when(updatedAccount.getBadges()).thenAnswer(stubbing); case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing); case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing);