Revert "Don't immediately require PNI-associated keys for "atomic" device linking"

This reverts commit 4ec97cf006.
This commit is contained in:
Jon Chambers 2023-08-09 15:34:26 -04:00 committed by Jon Chambers
parent bed33d042a
commit 2ecf3cb303
3 changed files with 39 additions and 73 deletions

View File

@ -21,7 +21,6 @@ import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -355,24 +354,23 @@ public class DeviceController {
final Optional<DeviceActivationRequest> maybeDeviceActivationRequest) final Optional<DeviceActivationRequest> maybeDeviceActivationRequest)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = checkVerificationToken(verificationCode) final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode);
.flatMap(accounts::getByAccountIdentifier)
final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier)
.orElseThrow(ForbiddenException::new); .orElseThrow(ForbiddenException::new);
rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid()); rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
assert deviceActivationRequest.aciSignedPreKey().isPresent(); assert deviceActivationRequest.aciSignedPreKey().isPresent();
assert deviceActivationRequest.pniSignedPreKey().isPresent();
assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); assert deviceActivationRequest.aciPqLastResortPreKey().isPresent();
assert deviceActivationRequest.pniPqLastResortPreKey().isPresent();
final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(), final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(),
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) && List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get()))
deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey -> && PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(),
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey))) List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get()));
.orElse(true) &&
deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey)))
.orElse(true);
if (!allKeysValid) { if (!allKeysValid) {
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
@ -411,8 +409,7 @@ public class DeviceController {
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get());
deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey);
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId()); device.setApnId(apnRegistrationId.apnRegistrationId());
@ -434,31 +431,24 @@ public class DeviceController {
deleteKeysFuture.join(); deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
final List<CompletableFuture<Void>> storeKeyFutures = new ArrayList<>(4); keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(), keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get()))); Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
storeKeyFutures.add(keys.storePqLastResort(a.getUuid(), Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get()))); keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey -> .join());
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniSignedPreKey))));
deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey ->
storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniPqLastResortPreKey))));
CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join();
});
a.addDevice(device); a.addDevice(device);
}); });
usedTokenCluster.useCluster(connection -> if (maybeAciFromToken.isPresent()) {
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device); return new Pair<>(updatedAccount, device);
} }

View File

@ -7,7 +7,6 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank;
import java.util.Optional; import java.util.Optional;
public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
@ -24,8 +23,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@JsonCreator @JsonCreator
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode, public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode,
@JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@ -39,14 +38,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@AssertTrue @AssertTrue
public boolean hasAllRequiredFields() { public boolean hasAllRequiredFields() {
// PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity
// key.
final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent()
^ deviceActivationRequest().pniPqLastResortPreKey().isPresent();
return deviceActivationRequest().aciSignedPreKey().isPresent() return deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent() && deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& !mismatchedPniKeys; && deviceActivationRequest().pniPqLastResortPreKey().isPresent();
} }
@AssertTrue @AssertTrue

View File

@ -245,8 +245,7 @@ class DeviceControllerTest {
final Optional<GcmRegistrationId> gcmRegistrationId, final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken, final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken, final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken, final Optional<String> expectedGcmToken) {
final boolean includePniKeys) {
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID); when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
@ -267,15 +266,12 @@ class DeviceControllerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty(); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty(); pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
if (includePniKeys) {
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
}
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -298,11 +294,7 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue(); final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey()); assertEquals(aciSignedPreKey.get(), device.getSignedPreKey());
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
if (includePniKeys) {
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
}
assertEquals(fetchesMessages, device.getFetchesMessages()); assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@ -315,35 +307,24 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId())); () -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
if (includePniKeys) {
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
} else {
verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any());
verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any());
}
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(commands).set(anyString(), anyString(), any()); verify(commands).set(anyString(), anyString(), any());
} }
private static Stream<Arguments> linkDeviceAtomic() { private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token"; final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token"; final String gcmToken = "gcm-token";
return Stream.of( return Stream.of(
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true), Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true), Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true), Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true), Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken))
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false)
); );
} }