diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index d13055040..27e9d7f6e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -21,6 +21,7 @@ import java.security.NoSuchAlgorithmException; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Base64; import java.util.LinkedList; import java.util.List; @@ -354,23 +355,24 @@ public class DeviceController { final Optional maybeDeviceActivationRequest) throws RateLimitExceededException, DeviceLimitExceededException { - final Optional maybeAciFromToken = checkVerificationToken(verificationCode); - - final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier) + final Account account = checkVerificationToken(verificationCode) + .flatMap(accounts::getByAccountIdentifier) .orElseThrow(ForbiddenException::new); rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid()); maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { assert deviceActivationRequest.aciSignedPreKey().isPresent(); - assert deviceActivationRequest.pniSignedPreKey().isPresent(); assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); - assert deviceActivationRequest.pniPqLastResortPreKey().isPresent(); final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(), - List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) - && PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), - List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get())); + List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) && + deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey -> + PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey))) + .orElse(true) && + deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey -> + PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey))) + .orElse(true); if (!allKeysValid) { throw new WebApplicationException(Response.status(422).build()); @@ -409,7 +411,8 @@ public class DeviceController { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); - device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get()); + + deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey); deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { device.setApnId(apnRegistrationId.apnRegistrationId()); @@ -431,24 +434,31 @@ public class DeviceController { deleteKeysFuture.join(); - maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( - keys.storeEcSignedPreKeys(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())), - keys.storePqLastResort(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())), - keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())), - keys.storePqLastResort(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get()))) - .join()); + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { + final List> storeKeyFutures = new ArrayList<>(4); + + storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get()))); + + storeKeyFutures.add(keys.storePqLastResort(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get()))); + + deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey -> + storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), pniSignedPreKey)))); + + deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey -> + storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), pniPqLastResortPreKey)))); + + CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join(); + }); a.addDevice(device); }); - if (maybeAciFromToken.isPresent()) { - usedTokenCluster.useCluster(connection -> - connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); - } + usedTokenCluster.useCluster(connection -> + connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); return new Pair<>(updatedAccount, device); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java index b12fb36a5..ba8bb4971 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java @@ -7,6 +7,7 @@ import io.swagger.v3.oas.annotations.media.Schema; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; +import javax.validation.constraints.NotBlank; import java.util.Optional; public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ @@ -23,8 +24,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI @JsonCreator @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, - @JsonProperty("accountAttributes") AccountAttributes accountAttributes, + public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode, + @JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @@ -38,10 +39,14 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI @AssertTrue public boolean hasAllRequiredFields() { + // PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity + // key. + final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent() + ^ deviceActivationRequest().pniPqLastResortPreKey().isPresent(); + return deviceActivationRequest().aciSignedPreKey().isPresent() - && deviceActivationRequest().pniSignedPreKey().isPresent() && deviceActivationRequest().aciPqLastResortPreKey().isPresent() - && deviceActivationRequest().pniPqLastResortPreKey().isPresent(); + && !mismatchedPniKeys; } @AssertTrue diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 9f29a1d06..a32426d08 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -245,7 +245,8 @@ class DeviceControllerTest { final Optional gcmRegistrationId, final Optional expectedApnsToken, final Optional expectedApnsVoipToken, - final Optional expectedGcmToken) { + final Optional expectedGcmToken, + final boolean includePniKeys) { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.MASTER_ID); @@ -266,12 +267,15 @@ class DeviceControllerTest { final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); - pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty(); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); - pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty(); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); - when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + + if (includePniKeys) { + when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + } when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -294,7 +298,11 @@ class DeviceControllerTest { final Device device = deviceCaptor.getValue(); assertEquals(aciSignedPreKey.get(), device.getSignedPreKey()); - assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey()); + + if (includePniKeys) { + assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey()); + } + assertEquals(fetchesMessages, device.getFetchesMessages()); expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), @@ -307,24 +315,35 @@ class DeviceControllerTest { () -> assertNull(device.getGcmId())); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); + + if (includePniKeys) { + verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); + verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); + } else { + verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any()); + verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any()); + } + verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); - verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); - verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); verify(commands).set(anyString(), anyString(), any()); } - private static Stream linkDeviceAtomic() { final String apnsToken = "apns-token"; final String apnsVoipToken = "apns-voip-token"; final String gcmToken = "gcm-token"; return Stream.of( - Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()), - Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken)) + Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true), + Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true), + + Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false), + Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false) ); }