Don't immediately require PNI-associated keys for "atomic" device linking

This commit is contained in:
Jon Chambers 2023-08-08 16:53:41 -04:00 committed by Jon Chambers
parent d51c6fd2f8
commit 4ec97cf006
3 changed files with 73 additions and 39 deletions

View File

@ -21,6 +21,7 @@ import java.security.NoSuchAlgorithmException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
@ -354,23 +355,24 @@ public class DeviceController {
final Optional<DeviceActivationRequest> maybeDeviceActivationRequest)
throws RateLimitExceededException, DeviceLimitExceededException {
final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode);
final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier)
final Account account = checkVerificationToken(verificationCode)
.flatMap(accounts::getByAccountIdentifier)
.orElseThrow(ForbiddenException::new);
rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
assert deviceActivationRequest.aciSignedPreKey().isPresent();
assert deviceActivationRequest.pniSignedPreKey().isPresent();
assert deviceActivationRequest.aciPqLastResortPreKey().isPresent();
assert deviceActivationRequest.pniPqLastResortPreKey().isPresent();
final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(),
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get()))
&& PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(),
List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get()));
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) &&
deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey)))
.orElse(true) &&
deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey)))
.orElse(true);
if (!allKeysValid) {
throw new WebApplicationException(Response.status(422).build());
@ -409,7 +411,8 @@ public class DeviceController {
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get());
deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey);
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
@ -431,24 +434,31 @@ public class DeviceController {
deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
.join());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
final List<CompletableFuture<Void>> storeKeyFutures = new ArrayList<>(4);
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())));
storeKeyFutures.add(keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())));
deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey ->
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniSignedPreKey))));
deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey ->
storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniPqLastResortPreKey))));
CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join();
});
a.addDevice(device);
});
if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
return new Pair<>(updatedAccount, device);
}

View File

@ -7,6 +7,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank;
import java.util.Optional;
public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
@ -23,8 +24,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@JsonCreator
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes,
public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode,
@JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@ -38,10 +39,14 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@AssertTrue
public boolean hasAllRequiredFields() {
// PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity
// key.
final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent()
^ deviceActivationRequest().pniPqLastResortPreKey().isPresent();
return deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& deviceActivationRequest().pniPqLastResortPreKey().isPresent();
&& !mismatchedPniKeys;
}
@AssertTrue

View File

@ -245,7 +245,8 @@ class DeviceControllerTest {
final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) {
final Optional<String> expectedGcmToken,
final boolean includePniKeys) {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
@ -266,12 +267,15 @@ class DeviceControllerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty();
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty();
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
if (includePniKeys) {
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
}
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -294,7 +298,11 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey());
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
if (includePniKeys) {
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
}
assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@ -307,24 +315,35 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
if (includePniKeys) {
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
} else {
verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any());
verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any());
}
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(commands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";
return Stream.of(
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken))
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true),
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false)
);
}