Revert "Don't immediately require PNI-associated keys for "atomic" device linking"

This reverts commit 4ec97cf006.
This commit is contained in:
Jon Chambers 2023-08-09 15:34:26 -04:00 committed by Jon Chambers
parent bed33d042a
commit 2ecf3cb303
3 changed files with 39 additions and 73 deletions

View File

@ -21,7 +21,6 @@ import java.security.NoSuchAlgorithmException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
@ -355,24 +354,23 @@ public class DeviceController {
final Optional<DeviceActivationRequest> maybeDeviceActivationRequest)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = checkVerificationToken(verificationCode)
.flatMap(accounts::getByAccountIdentifier)
final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode);
final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier)
.orElseThrow(ForbiddenException::new);
rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
assert deviceActivationRequest.aciSignedPreKey().isPresent();
assert deviceActivationRequest.pniSignedPreKey().isPresent();
assert deviceActivationRequest.aciPqLastResortPreKey().isPresent();
assert deviceActivationRequest.pniPqLastResortPreKey().isPresent();
final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(),
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get())) &&
deviceActivationRequest.pniSignedPreKey().map(pniSignedPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniSignedPreKey)))
.orElse(true) &&
deviceActivationRequest.pniPqLastResortPreKey().map(pniPqLastResortPreKey ->
PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(), List.of(pniPqLastResortPreKey)))
.orElse(true);
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get()))
&& PreKeySignatureValidator.validatePreKeySignatures(account.getPhoneNumberIdentityKey(),
List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get()));
if (!allKeysValid) {
throw new WebApplicationException(Response.status(422).build());
@ -411,8 +409,7 @@ public class DeviceController {
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get());
deviceActivationRequest.pniSignedPreKey().ifPresent(device::setPhoneNumberIdentitySignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get());
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
@ -434,31 +431,24 @@ public class DeviceController {
deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
final List<CompletableFuture<Void>> storeKeyFutures = new ArrayList<>(4);
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())));
storeKeyFutures.add(keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())));
deviceActivationRequest.pniSignedPreKey().ifPresent(pniSignedPreKey ->
storeKeyFutures.add(keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniSignedPreKey))));
deviceActivationRequest.pniPqLastResortPreKey().ifPresent(pniPqLastResortPreKey ->
storeKeyFutures.add(keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), pniPqLastResortPreKey))));
CompletableFuture.allOf(storeKeyFutures.toArray(new CompletableFuture[0])).join();
});
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get())))
.join());
a.addDevice(device);
});
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device);
}

View File

@ -7,7 +7,6 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank;
import java.util.Optional;
public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
@ -24,8 +23,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@JsonCreator
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") @NotBlank String verificationCode,
@JsonProperty("accountAttributes") @Valid AccountAttributes accountAttributes,
public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@ -39,14 +38,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@AssertTrue
public boolean hasAllRequiredFields() {
// PNI-associated credentials are not yet required, but will be when all devices are assumed to have a PNI identity
// key.
final boolean mismatchedPniKeys = deviceActivationRequest().pniSignedPreKey().isPresent()
^ deviceActivationRequest().pniPqLastResortPreKey().isPresent();
return deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& !mismatchedPniKeys;
&& deviceActivationRequest().pniPqLastResortPreKey().isPresent();
}
@AssertTrue

View File

@ -245,8 +245,7 @@ class DeviceControllerTest {
final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken,
final boolean includePniKeys) {
final Optional<String> expectedGcmToken) {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
@ -267,15 +266,12 @@ class DeviceControllerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = includePniKeys ? Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)) : Optional.empty();
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = includePniKeys ? Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)) : Optional.empty();
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
if (includePniKeys) {
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
}
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -298,11 +294,7 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey());
if (includePniKeys) {
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
}
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@ -315,35 +307,24 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
if (includePniKeys) {
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
} else {
verify(keysManager, never()).storeEcSignedPreKeys(eq(AuthHelper.VALID_PNI), any());
verify(keysManager, never()).storePqLastResort(eq(AuthHelper.VALID_PNI), any());
}
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(commands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";
return Stream.of(
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), true),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), true),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), true),
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty(), false),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty(), false),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken), false)
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken))
);
}