Don't immediately require PNI-associated keys for "atomic" device linking

This commit is contained in:
Jon Chambers 2023-08-08 16:53:41 -04:00 committed by Jon Chambers
parent d51c6fd2f8
commit 4ec97cf006
3 changed files with 73 additions and 39 deletions

View File

@ -21,6 +21,7 @@ import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -354,23 +355,24 @@ public class DeviceController {
final Optional<DeviceActivationRequest> maybeDeviceActivationRequest) final Optional<DeviceActivationRequest> maybeDeviceActivationRequest)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode); final Account account = checkVerificationToken(verificationCode)
.flatMap(accounts::getByAccountIdentifier)
final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier)
.orElseThrow(ForbiddenException::new); .orElseThrow(ForbiddenException::new);
rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid()); rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
assert deviceActivationRequest.aciSignedPreKey().isPresent(); assert deviceActivationRequest.aciSignedPreKey().isPresent();
assert deviceActivationRequest.pniSignedPreKey().isPresent();
assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); assert deviceActivationRequest.aciPqLastResortPreKey().isPresent();
assert deviceActivationRequest.pniPqLastResortPreKey().isPresent();
final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(), final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(),
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) &&
&& PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey ->
List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get())); PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey)))
.orElse(true) &&
deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey)))
.orElse(true);
if (!allKeysValid) { if (!allKeysValid) {
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
@ -409,7 +411,8 @@ public class DeviceController {
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get());
deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey);
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId()); device.setApnId(apnRegistrationId.apnRegistrationId());
@ -431,24 +434,31 @@ public class DeviceController {
deleteKeysFuture.join(); deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
keys.storeEcSignedPreKeys(a.getUuid(), final List<CompletableFuture<Void>> storeKeyFutures = new ArrayList<>(4);
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
keys.storePqLastResort(a.getUuid(), storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())), Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())));
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())), storeKeyFutures.add(keys.storePqLastResort(a.getUuid(),
keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())));
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
.join()); deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey ->
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniSignedPreKey))));
deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey ->
storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniPqLastResortPreKey))));
CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join();
});
a.addDevice(device); a.addDevice(device);
}); });
if (maybeAciFromToken.isPresent()) { usedTokenCluster.useCluster(connection ->
usedTokenCluster.useCluster(connection -> connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device); return new Pair<>(updatedAccount, device);
} }

View File

@ -7,6 +7,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank;
import java.util.Optional; import java.util.Optional;
public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
@ -23,8 +24,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@JsonCreator @JsonCreator
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@ -38,10 +39,14 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@AssertTrue @AssertTrue
public boolean hasAllRequiredFields() { public boolean hasAllRequiredFields() {
// PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity
// key.
final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent()
^ deviceActivationRequest().pniPqLastResortPreKey().isPresent();
return deviceActivationRequest().aciSignedPreKey().isPresent() return deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent() && deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& deviceActivationRequest().pniPqLastResortPreKey().isPresent(); && !mismatchedPniKeys;
} }
@AssertTrue @AssertTrue

View File

@ -245,7 +245,8 @@ class DeviceControllerTest {
final Optional<GcmRegistrationId> gcmRegistrationId, final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken, final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken, final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) { final Optional<String> expectedGcmToken,
final boolean includePniKeys) {
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID); when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
@ -266,12 +267,15 @@ class DeviceControllerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty();
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty();
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
if (includePniKeys) {
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
}
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -294,7 +298,11 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue(); final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey()); assertEquals(aciSignedPreKey.get(), device.getSignedPreKey());
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
if (includePniKeys) {
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
}
assertEquals(fetchesMessages, device.getFetchesMessages()); assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@ -307,24 +315,35 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId())); () -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
if (includePniKeys) {
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
} else {
verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any());
verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any());
}
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(commands).set(anyString(), anyString(), any()); verify(commands).set(anyString(), anyString(), any());
} }
private static Stream<Arguments> linkDeviceAtomic() { private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token"; final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token"; final String gcmToken = "gcm-token";
return Stream.of( return Stream.of(
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()), Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()), Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken)) Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true),
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false)
); );
} }