From 2ecf3cb303b10f11f46862023bb38139320615b7 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 9 Aug 2023 15:34:26 -0400 Subject: [PATCH] Revert "Don't immediately require PNI-associated keys for "atomic" device linking" This reverts commit 4ec97cf00673b4d3bc3d2fd7d60a94513b8885f2. --- .../controllers/DeviceController.java | 56 ++++++++----------- .../entities/LinkDeviceRequest.java | 13 ++--- .../controllers/DeviceControllerTest.java | 43 ++++---------- 3 files changed, 39 insertions(+), 73 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 27e9d7f6e..d13055040 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -21,7 +21,6 @@ import java.security.NoSuchAlgorithmException; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.ArrayList; import java.util.Base64; import java.util.LinkedList; import java.util.List; @@ -355,24 +354,23 @@ public class DeviceController { final Optional maybeDeviceActivationRequest) throws RateLimitExceededException, DeviceLimitExceededException { - final Account account = checkVerificationToken(verificationCode) - .flatMap(accounts::getByAccountIdentifier) + final Optional maybeAciFromToken = checkVerificationToken(verificationCode); + + final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier) .orElseThrow(ForbiddenException::new); rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid()); maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { assert deviceActivationRequest.aciSignedPreKey().isPresent(); + assert deviceActivationRequest.pniSignedPreKey().isPresent(); assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); + assert deviceActivationRequest.pniPqLastResortPreKey().isPresent(); final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(), - List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) && - deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey -> - PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey))) - .orElse(true) && - deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey -> - PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey))) - .orElse(true); + List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) + && PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), + List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get())); if (!allKeysValid) { throw new WebApplicationException(Response.status(422).build()); @@ -411,8 +409,7 @@ public class DeviceController { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); - - deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey); + device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get()); deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { device.setApnId(apnRegistrationId.apnRegistrationId()); @@ -434,31 +431,24 @@ public class DeviceController { deleteKeysFuture.join(); - maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { - final List> storeKeyFutures = new ArrayList<>(4); - - storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get()))); - - storeKeyFutures.add(keys.storePqLastResort(a.getUuid(), - Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get()))); - - deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey -> - storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), pniSignedPreKey)))); - - deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey -> - storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(), - Map.of(device.getId(), pniPqLastResortPreKey)))); - - CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join(); - }); + maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( + keys.storeEcSignedPreKeys(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())), + keys.storePqLastResort(a.getUuid(), + Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())), + keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())), + keys.storePqLastResort(a.getPhoneNumberIdentifier(), + Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get()))) + .join()); a.addDevice(device); }); - usedTokenCluster.useCluster(connection -> - connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); + if (maybeAciFromToken.isPresent()) { + usedTokenCluster.useCluster(connection -> + connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); + } return new Pair<>(updatedAccount, device); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java index ba8bb4971..b12fb36a5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java @@ -7,7 +7,6 @@ import io.swagger.v3.oas.annotations.media.Schema; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.NotBlank; import java.util.Optional; public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ @@ -24,8 +23,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI @JsonCreator @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode, - @JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes, + public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, + @JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @@ -39,14 +38,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI @AssertTrue public boolean hasAllRequiredFields() { - // PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity - // key. - final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent() - ^ deviceActivationRequest().pniPqLastResortPreKey().isPresent(); - return deviceActivationRequest().aciSignedPreKey().isPresent() + && deviceActivationRequest().pniSignedPreKey().isPresent() && deviceActivationRequest().aciPqLastResortPreKey().isPresent() - && !mismatchedPniKeys; + && deviceActivationRequest().pniPqLastResortPreKey().isPresent(); } @AssertTrue diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index a32426d08..9f29a1d06 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -245,8 +245,7 @@ class DeviceControllerTest { final Optional gcmRegistrationId, final Optional expectedApnsToken, final Optional expectedApnsVoipToken, - final Optional expectedGcmToken, - final boolean includePniKeys) { + final Optional expectedGcmToken) { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.MASTER_ID); @@ -267,15 +266,12 @@ class DeviceControllerTest { final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); - pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty(); + pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); - pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty(); + pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); - - if (includePniKeys) { - when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); - } + when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -298,11 +294,7 @@ class DeviceControllerTest { final Device device = deviceCaptor.getValue(); assertEquals(aciSignedPreKey.get(), device.getSignedPreKey()); - - if (includePniKeys) { - assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey()); - } - + assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey()); assertEquals(fetchesMessages, device.getFetchesMessages()); expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), @@ -315,35 +307,24 @@ class DeviceControllerTest { () -> assertNull(device.getGcmId())); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); - - if (includePniKeys) { - verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); - verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); - } else { - verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any()); - verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any()); - } - verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); + verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); + verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); verify(commands).set(anyString(), anyString(), any()); } + private static Stream linkDeviceAtomic() { final String apnsToken = "apns-token"; final String apnsVoipToken = "apns-voip-token"; final String gcmToken = "gcm-token"; return Stream.of( - Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true), - Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true), - - Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false), - Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false), - Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false) + Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()), + Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()), + Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken)) ); }