Always require atomic account creation

This commit is contained in:
Jon Chambers 2023-11-11 10:07:16 -08:00 committed by Jon Chambers
parent 9069c5abb6
commit 521900c048
12 changed files with 409 additions and 584 deletions

View File

@ -75,36 +75,6 @@ public final class Operations {
INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join(); INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join();
// register account
final RegistrationRequest registrationRequest = new RegistrationRequest(
null, registrationPassword, accountAttributes, true, false,
Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest)
.authorized(number, accountPassword)
.executeExpectSuccess(AccountIdentityResponse.class);
user.setAciUuid(registrationResponse.uuid());
user.setPniUuid(registrationResponse.pni());
// upload pre-key
final TestUser.PreKeySetPublicView preKeySetPublicView = user.preKeys(Device.PRIMARY_ID, false);
apiPut("/v2/keys", preKeySetPublicView)
.authorized(user, Device.PRIMARY_ID)
.executeExpectSuccess();
return user;
}
public static TestUser newRegisteredUserAtomic(final String number) {
final byte[] registrationPassword = RandomUtils.nextBytes(32);
final String accountPassword = Base64.getEncoder().encodeToString(RandomUtils.nextBytes(32));
final TestUser user = TestUser.create(number, accountPassword, registrationPassword);
final AccountAttributes accountAttributes = user.accountAttributes();
INTEGRATION_TOOLS.populateRecoveryPassword(number, registrationPassword).join();
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -113,13 +83,12 @@ public final class Operations {
registrationPassword, registrationPassword,
accountAttributes, accountAttributes,
true, true,
true, new IdentityKey(aciIdentityKeyPair.getPublicKey()),
Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())), new IdentityKey(pniIdentityKeyPair.getPublicKey()),
Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())), generateSignedECPreKey(1, aciIdentityKeyPair),
Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)), generateSignedECPreKey(2, pniIdentityKeyPair),
Optional.of(generateSignedECPreKey(2, pniIdentityKeyPair)), generateSignedKEMPreKey(3, aciIdentityKeyPair),
Optional.of(generateSignedKEMPreKey(3, aciIdentityKeyPair)), generateSignedKEMPreKey(4, pniIdentityKeyPair),
Optional.of(generateSignedKEMPreKey(4, pniIdentityKeyPair)),
Optional.empty(), Optional.empty(),
Optional.empty()); Optional.empty());

View File

@ -18,7 +18,6 @@ import org.signal.libsignal.usernames.Username;
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.EncryptedUsername;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
@ -41,7 +40,7 @@ public class AccountTest {
@Test @Test
public void testCreateAccountAtomic() throws Exception { public void testCreateAccountAtomic() throws Exception {
final TestUser user = Operations.newRegisteredUserAtomic("+19995550201"); final TestUser user = Operations.newRegisteredUser("+19995550201");
try { try {
final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami") final Pair<Integer, AccountIdentityResponse> execute = Operations.apiGet("/v1/accounts/whoami")
.authorized(user) .authorized(user)

View File

@ -11,8 +11,7 @@ import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
@ -21,19 +20,10 @@ import org.whispersystems.textsecuregcm.storage.Device;
public class MessagingTest { public class MessagingTest {
@ParameterizedTest @Test
@ValueSource(booleans = {true, false}) public void testSendMessageUnsealed() {
public void testSendMessageUnsealed(final boolean atomicAccountCreation) throws Exception { final TestUser userA = Operations.newRegisteredUser("+19995550102");
final TestUser userA; final TestUser userB = Operations.newRegisteredUser("+19995550103");
final TestUser userB;
if (atomicAccountCreation) {
userA = Operations.newRegisteredUserAtomic("+19995550102");
userB = Operations.newRegisteredUserAtomic("+19995550103");
} else {
userA = Operations.newRegisteredUser("+19995550104");
userB = Operations.newRegisteredUser("+19995550105");
}
try { try {
final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8); final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8);

View File

@ -77,12 +77,12 @@ public class DeviceController {
static final int MAX_DEVICES = 6; static final int MAX_DEVICES = 6;
private final Key verificationTokenKey; private final Key verificationTokenKey;
private final AccountsManager accounts; private final AccountsManager accounts;
private final MessagesManager messages; private final MessagesManager messages;
private final KeysManager keys; private final KeysManager keys;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final FaultTolerantRedisCluster usedTokenCluster; private final FaultTolerantRedisCluster usedTokenCluster;
private final Map<String, Integer> maxDeviceConfiguration; private final Map<String, Integer> maxDeviceConfiguration;
private final Clock clock; private final Clock clock;
@ -217,8 +217,8 @@ public class DeviceController {
@ChangesDeviceEnabledState @ChangesDeviceEnabledState
@Operation(summary = "Link a device to an account", @Operation(summary = "Link a device to an account",
description = """ description = """
Links a device to an account identified by a given phone number. Links a device to an account identified by a given phone number.
""") """)
@ApiResponse(responseCode = "200", description = "The new device was linked to the calling account", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "The new device was linked to the calling account", useReturnTypeSchema = true)
@ApiResponse(responseCode = "403", description = "The given account was not found or the given verification code was incorrect") @ApiResponse(responseCode = "403", description = "The given account was not found or the given verification code was incorrect")
@ApiResponse(responseCode = "411", description = "The given account already has its maximum number of linked devices") @ApiResponse(responseCode = "411", description = "The given account already has its maximum number of linked devices")
@ -227,8 +227,8 @@ public class DeviceController {
name = "Retry-After", name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader, public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader,
@NotNull @Valid LinkDeviceRequest linkDeviceRequest, @NotNull @Valid LinkDeviceRequest linkDeviceRequest,
@Context ContainerRequest containerRequest) @Context ContainerRequest containerRequest)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Pair<Account, Device> accountAndDevice = createDevice(authorizationHeader.getPassword(), final Pair<Account, Device> accountAndDevice = createDevice(authorizationHeader.getPassword(),
@ -296,7 +296,8 @@ public class DeviceController {
return Optional.empty(); return Optional.empty();
} }
final byte[] expectedSignature = getInitializedMac().doFinal(claimsAndSignature[0].getBytes(StandardCharsets.UTF_8)); final byte[] expectedSignature = getInitializedMac().doFinal(
claimsAndSignature[0].getBytes(StandardCharsets.UTF_8));
final byte[] providedSignature; final byte[] providedSignature;
try { try {
@ -345,10 +346,10 @@ public class DeviceController {
} }
private Pair<Account, Device> createDevice(final String password, private Pair<Account, Device> createDevice(final String password,
final String verificationCode, final String verificationCode,
final AccountAttributes accountAttributes, final AccountAttributes accountAttributes,
final ContainerRequest containerRequest, final ContainerRequest containerRequest,
final Optional<DeviceActivationRequest> maybeDeviceActivationRequest) final Optional<DeviceActivationRequest> maybeDeviceActivationRequest)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode); final Optional<UUID> maybeAciFromToken = checkVerificationToken(verificationCode);
@ -359,16 +360,11 @@ public class DeviceController {
rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid()); rateLimiters.getVerifyDeviceLimiter().validate(account.getUuid());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
assert deviceActivationRequest.aciSignedPreKey().isPresent(); final boolean allKeysValid =
assert deviceActivationRequest.pniSignedPreKey().isPresent(); PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(IdentityType.ACI),
assert deviceActivationRequest.aciPqLastResortPreKey().isPresent(); List.of(deviceActivationRequest.aciSignedPreKey(), deviceActivationRequest.aciPqLastResortPreKey()))
assert deviceActivationRequest.pniPqLastResortPreKey().isPresent(); && PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(IdentityType.PNI),
List.of(deviceActivationRequest.pniSignedPreKey(), deviceActivationRequest.pniPqLastResortPreKey()));
final boolean allKeysValid = PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(
IdentityType.ACI),
List.of(deviceActivationRequest.aciSignedPreKey().get(), deviceActivationRequest.aciPqLastResortPreKey().get()))
&& PreKeySignatureValidator.validatePreKeySignatures(account.getIdentityKey(IdentityType.PNI),
List.of(deviceActivationRequest.pniSignedPreKey().get(), deviceActivationRequest.pniPqLastResortPreKey().get()));
if (!allKeysValid) { if (!allKeysValid) {
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
@ -406,8 +402,8 @@ public class DeviceController {
device.setCapabilities(accountAttributes.getCapabilities()); device.setCapabilities(accountAttributes.getCapabilities());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey().get()); device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey().get()); device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey());
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId()); device.setApnId(apnRegistrationId.apnRegistrationId());
@ -431,13 +427,13 @@ public class DeviceController {
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf( maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(), keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey().get())), Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())),
keys.storePqLastResort(a.getUuid(), keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey().get())), Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey().get())), Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(), keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey().get()))) Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey())))
.join()); .join());
a.addDevice(device); a.addDevice(device);

View File

@ -64,7 +64,6 @@ public class RegistrationController {
private static final String COUNTRY_CODE_TAG_NAME = "countryCode"; private static final String COUNTRY_CODE_TAG_NAME = "countryCode";
private static final String REGION_CODE_TAG_NAME = "regionCode"; private static final String REGION_CODE_TAG_NAME = "regionCode";
private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
private static final String ACCOUNT_ACTIVATED_TAG_NAME = "accountActivated";
private static final String INVALID_ACCOUNT_ATTRS_COUNTER_NAME = name(RegistrationController.class, "invalidAccountAttrs"); private static final String INVALID_ACCOUNT_ATTRS_COUNTER_NAME = name(RegistrationController.class, "invalidAccountAttrs");
private final AccountsManager accounts; private final AccountsManager accounts;
@ -145,50 +144,39 @@ public class RegistrationController {
Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));
// If the request includes all the information we need to fully "activate" the account, we should do so account = accounts.update(account, a -> {
if (registrationRequest.supportsAtomicAccountCreation()) { a.setIdentityKey(registrationRequest.aciIdentityKey());
assert registrationRequest.aciIdentityKey().isPresent(); a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey());
assert registrationRequest.pniIdentityKey().isPresent();
assert registrationRequest.deviceActivationRequest().aciSignedPreKey().isPresent();
assert registrationRequest.deviceActivationRequest().pniSignedPreKey().isPresent();
assert registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().isPresent();
assert registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().isPresent();
account = accounts.update(account, a -> { final Device device = a.getPrimaryDevice().orElseThrow();
a.setIdentityKey(registrationRequest.aciIdentityKey().get());
a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey().get());
final Device device = a.getPrimaryDevice().orElseThrow(); device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey());
device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey().get()); registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> {
device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey().get()); device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey().get())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey().get())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())))
.join();
}); });
}
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey())))
.join();
});
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)), Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)),
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name()), Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name())))
Tag.of(ACCOUNT_ACTIVATED_TAG_NAME, String.valueOf(account.isEnabled()))))
.increment(); .increment();
return new AccountIdentityResponse(account.getUuid(), return new AccountIdentityResponse(account.getUuid(),

View File

@ -122,4 +122,9 @@ public class AccountAttributes {
this.recoveryPassword = recoveryPassword; this.recoveryPassword = recoveryPassword;
return this; return this;
} }
@VisibleForTesting
public void setPhoneNumberIdentityRegistrationId(final Integer phoneNumberIdentityRegistrationId) {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}
} }

View File

@ -3,52 +3,52 @@ package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import java.util.Optional; import java.util.Optional;
public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ public record DeviceActivationRequest(
A signed EC pre-key to be associated with this account's ACI. If provided, an account @NotNull
will be created "atomically," and all other properties needed for atomic account @Valid
creation must also be present. @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
""") A signed EC pre-key to be associated with this account's ACI.
Optional<@Valid ECSignedPreKey> aciSignedPreKey, """)
ECSignedPreKey aciSignedPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @NotNull
A signed EC pre-key to be associated with this account's PNI. If provided, an account @Valid
will be created "atomically," and all other properties needed for atomic account @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
creation must also be present. A signed EC pre-key to be associated with this account's PNI.
""") """)
Optional<@Valid ECSignedPreKey> pniSignedPreKey, ECSignedPreKey pniSignedPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @NotNull
A signed Kyber-1024 "last resort" pre-key to be associated with this account's ACI. If @Valid
provided, an account will be created "atomically," and all other properties needed for @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
atomic account creation must also be present. A signed Kyber-1024 "last resort" pre-key to be associated with this account's ACI.
""") """)
Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, KEMSignedPreKey aciPqLastResortPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @NotNull
A signed Kyber-1024 "last resort" pre-key to be associated with this account's PNI. If @Valid
provided, an account will be created "atomically," and all other properties needed for @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
atomic account creation must also be present. A signed Kyber-1024 "last resort" pre-key to be associated with this account's PNI.
""") """)
Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey, KEMSignedPreKey pniPqLastResortPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
An APNs token set for the account's primary device. If provided, the account's primary An APNs token set for the account's primary device. If provided, the account's primary
device will be notified of new messages via push notifications to the given token. If device will be notified of new messages via push notifications to the given token.
creating an account "atomically," callers must provide exactly one of an APNs token Callers must provide exactly one of an APNs token set, an FCM token, or an
set, an FCM token, or an `AccountAttributes` entity with `fetchesMessages` set to `AccountAttributes` entity with `fetchesMessages` set to `true`.
`true`. """)
""") Optional<@Valid ApnRegistrationId> apnToken,
Optional<@Valid ApnRegistrationId> apnToken,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
An FCM/GCM token for the account's primary device. If provided, the account's primary An FCM/GCM token for the account's primary device. If provided, the account's primary
device will be notified of new messages via push notifications to the given token. If device will be notified of new messages via push notifications to the given token.
creating an account "atomically," callers must provide exactly one of an APNs token Callers must provide exactly one of an APNs token set, an FCM token, or an
set, an FCM token, or an `AccountAttributes` entity with `fetchesMessages` set to `AccountAttributes` entity with `fetchesMessages` set to `true`.
`true`. """)
""") Optional<@Valid GcmRegistrationId> gcmToken) {
Optional<@Valid GcmRegistrationId> gcmToken) {
} }

View File

@ -7,6 +7,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull;
import java.util.Optional; import java.util.Optional;
public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
@ -17,6 +18,8 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
AccountAttributes accountAttributes, AccountAttributes accountAttributes,
@NotNull
@Valid
@JsonUnwrapped @JsonUnwrapped
@JsonProperty(access = JsonProperty.Access.READ_ONLY) @JsonProperty(access = JsonProperty.Access.READ_ONLY)
DeviceActivationRequest deviceActivationRequest) { DeviceActivationRequest deviceActivationRequest) {
@ -25,10 +28,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") @NotNull @Valid ECSignedPreKey aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") @NotNull @Valid ECSignedPreKey pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") @NotNull @Valid KEMSignedPreKey aciPqLastResortPreKey,
@JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey, @JsonProperty("pniPqLastResortPreKey") @NotNull @Valid KEMSignedPreKey pniPqLastResortPreKey,
@JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken,
@JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) {
@ -36,14 +39,6 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken));
} }
@AssertTrue
public boolean hasAllRequiredFields() {
return deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& deviceActivationRequest().pniPqLastResortPreKey().isPresent();
}
@AssertTrue @AssertTrue
public boolean hasExactlyOneMessageDeliveryChannel() { public boolean hasExactlyOneMessageDeliveryChannel() {
if (accountAttributes.getFetchesMessages()) { if (accountAttributes.getFetchesMessages()) {

View File

@ -19,7 +19,7 @@ import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ID of an existing verification session as it appears in a verification session The ID of an existing verification session as it appears in a verification session
@ -50,31 +50,26 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
""") """)
boolean skipDeviceTransfer, boolean skipDeviceTransfer,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @NotNull
If true, indicates that this is a request for "atomic" registration. If any properties @Valid
needed for atomic account creation are not present, the request will fail. If false, @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
atomic account creation can still occur, but only if all required fields are present. The ACI-associated identity key for the account, encoded as a base64 string.
""") """)
boolean requireAtomic, @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
IdentityKey aciIdentityKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @NotNull
The ACI-associated identity key for the account, encoded as a base64 string. If @Valid
provided, an account will be created "atomically," and all other properties needed for @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
atomic account creation must also be present. The PNI-associated identity key for the account, encoded as a base64 string.
""") """)
@JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
Optional<IdentityKey> aciIdentityKey, IdentityKey pniIdentityKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The PNI-associated identity key for the account, encoded as a base64 string. If
provided, an account will be created "atomically," and all other properties needed for
atomic account creation must also be present.
""")
@JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class)
Optional<IdentityKey> pniIdentityKey,
@NotNull
@Valid
@JsonUnwrapped @JsonUnwrapped
@JsonProperty(access = JsonProperty.Access.READ_ONLY) @JsonProperty(access = JsonProperty.Access.READ_ONLY)
DeviceActivationRequest deviceActivationRequest) implements PhoneVerificationRequest { DeviceActivationRequest deviceActivationRequest) implements PhoneVerificationRequest {
@ -85,65 +80,37 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@JsonProperty("recoveryPassword") byte[] recoveryPassword, @JsonProperty("recoveryPassword") byte[] recoveryPassword,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("requireAtomic") boolean requireAtomic, @JsonProperty("aciIdentityKey") @NotNull @Valid IdentityKey aciIdentityKey,
@JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey, @JsonProperty("pniIdentityKey") @NotNull @Valid IdentityKey pniIdentityKey,
@JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey, @JsonProperty("aciSignedPreKey") @NotNull @Valid ECSignedPreKey aciSignedPreKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("pniSignedPreKey") @NotNull @Valid ECSignedPreKey pniSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, @JsonProperty("aciPqLastResortPreKey") @NotNull @Valid KEMSignedPreKey aciPqLastResortPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @JsonProperty("pniPqLastResortPreKey") @NotNull @Valid KEMSignedPreKey pniPqLastResortPreKey,
@JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey,
@JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken,
@JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) {
// This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in // This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in
// records, and this is a workaround. Please see // records, and this is a workaround. Please see
// https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context. // https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context.
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, requireAtomic, aciIdentityKey, pniIdentityKey, this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, aciIdentityKey, pniIdentityKey,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken));
} }
@AssertTrue @AssertTrue
public boolean isEverySignedKeyValid() { public boolean isEverySignedKeyValid() {
return validatePreKeySignature(aciIdentityKey(), deviceActivationRequest().aciSignedPreKey()) if (deviceActivationRequest().aciSignedPreKey() == null ||
&& validatePreKeySignature(pniIdentityKey(), deviceActivationRequest().pniSignedPreKey()) deviceActivationRequest().pniSignedPreKey() == null ||
&& validatePreKeySignature(aciIdentityKey(), deviceActivationRequest().aciPqLastResortPreKey()) deviceActivationRequest().aciPqLastResortPreKey() == null ||
&& validatePreKeySignature(pniIdentityKey(), deviceActivationRequest().pniPqLastResortPreKey()); deviceActivationRequest().pniPqLastResortPreKey() == null) {
} return false;
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") return PreKeySignatureValidator.validatePreKeySignatures(aciIdentityKey(), List.of(deviceActivationRequest().aciSignedPreKey(), deviceActivationRequest().aciPqLastResortPreKey()))
private static boolean validatePreKeySignature(final Optional<IdentityKey> maybeIdentityKey, && PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey(), List.of(deviceActivationRequest().pniSignedPreKey(), deviceActivationRequest().pniPqLastResortPreKey()));
final Optional<? extends SignedPreKey<?>> maybeSignedPreKey) {
return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey
.map(identityKey -> PreKeySignatureValidator.validatePreKeySignatures(identityKey, List.of(signedPreKey)))
.orElse(false))
.orElse(true);
}
@AssertTrue
public boolean isCompleteRequest() {
final boolean hasNoAtomicAccountCreationParameters =
aciIdentityKey().isEmpty()
&& pniIdentityKey().isEmpty()
&& deviceActivationRequest().aciSignedPreKey().isEmpty()
&& deviceActivationRequest().pniSignedPreKey().isEmpty()
&& deviceActivationRequest().aciPqLastResortPreKey().isEmpty()
&& deviceActivationRequest().pniPqLastResortPreKey().isEmpty();
return supportsAtomicAccountCreation() || (!requireAtomic() && hasNoAtomicAccountCreationParameters);
}
public boolean supportsAtomicAccountCreation() {
return hasExactlyOneMessageDeliveryChannel()
&& aciIdentityKey().isPresent()
&& pniIdentityKey().isPresent()
&& deviceActivationRequest().aciSignedPreKey().isPresent()
&& deviceActivationRequest().pniSignedPreKey().isPresent()
&& deviceActivationRequest().aciPqLastResortPreKey().isPresent()
&& deviceActivationRequest().pniPqLastResortPreKey().isPresent();
} }
@VisibleForTesting @VisibleForTesting
@AssertTrue
boolean hasExactlyOneMessageDeliveryChannel() { boolean hasExactlyOneMessageDeliveryChannel() {
if (accountAttributes.getFetchesMessages()) { if (accountAttributes.getFetchesMessages()) {
return deviceActivationRequest().apnToken().isEmpty() && deviceActivationRequest().gcmToken().isEmpty(); return deviceActivationRequest().apnToken().isEmpty() && deviceActivationRequest().gcmToken().isEmpty();

View File

@ -1,53 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import java.util.Optional;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException;
public class OptionalIdentityKeyAdapter {
public static class Serializer extends JsonSerializer<Optional<IdentityKey>> {
@Override
public void serialize(final Optional<IdentityKey> maybePublicKey,
final JsonGenerator jsonGenerator,
final SerializerProvider serializers) throws IOException {
if (maybePublicKey.isPresent()) {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(maybePublicKey.get().serialize()));
} else {
jsonGenerator.writeNull();
}
}
}
public static class Deserializer extends JsonDeserializer<Optional<IdentityKey>> {
@Override
public Optional<IdentityKey> deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) throws IOException {
try {
return Optional.of(new IdentityKey(Base64.getDecoder().decode(jsonParser.getValueAsString())));
} catch (final InvalidKeyException e) {
throw new IOException(e);
}
}
@Override
public Optional<IdentityKey> getNullValue(DeserializationContext ctxt) {
return Optional.empty();
}
}
}

View File

@ -288,18 +288,18 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(VerificationCode.class);
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@ -324,8 +324,8 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue(); final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey(IdentityType.ACI)); assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey.get(), device.getSignedPreKey(IdentityType.PNI)); assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
assertEquals(fetchesMessages, device.getFetchesMessages()); assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@ -338,14 +338,13 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId())); () -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID)); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey));
verify(commands).set(anyString(), anyString(), any()); verify(commands).set(anyString(), anyString(), any());
} }
private static Stream<Arguments> linkDeviceAtomic() { private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token"; final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
@ -368,18 +367,18 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@ -421,18 +420,18 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(VerificationCode.class);
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@ -465,10 +464,10 @@ class DeviceControllerTest {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey, void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final Optional<ECSignedPreKey> aciSignedPreKey, final ECSignedPreKey aciSignedPreKey,
final Optional<ECSignedPreKey> pniSignedPreKey, final ECSignedPreKey pniSignedPreKey,
final Optional<KEMSignedPreKey> aciPqLastResortPreKey, final KEMSignedPreKey aciPqLastResortPreKey,
final Optional<KEMSignedPreKey> pniPqLastResortPreKey) { final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@ -503,19 +502,19 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Optional<ECSignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final Optional<ECSignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final Optional<KEMSignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final Optional<KEMSignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of( return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, null, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, Optional.empty(), aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, null, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, Optional.empty(), pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, null, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, Optional.empty()) Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, null)
); );
} }
@ -545,7 +544,7 @@ class DeviceControllerTest {
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(true, 1234, null, null, true, null), new AccountAttributes(true, 1234, null, null, true, null),
new DeviceActivationRequest(Optional.of(aciSignedPreKey), Optional.of(pniSignedPreKey), Optional.of(aciPqLastResortPreKey), Optional.of(pniPqLastResortPreKey), Optional.empty(), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link") .target("/v1/devices/link")

View File

@ -6,10 +6,8 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -20,11 +18,11 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Base64; import java.util.Base64;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -49,7 +47,6 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets; import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
@ -60,6 +57,7 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@ -118,14 +116,20 @@ class RegistrationControllerTest {
@BeforeEach @BeforeEach
void setUp() { void setUp() {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
}
@Test when(accountsManager.update(any(), any())).thenAnswer(invocation -> {
public void testRegistrationRequest() throws Exception { final Account account = invocation.getArgument(0);
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); final Consumer<Account> accountUpdater = invocation.getArgument(1);
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); accountUpdater.accept(account);
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
return invocation.getArgument(0);
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
} }
@Test @Test
@ -151,32 +155,26 @@ class RegistrationControllerTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource() @MethodSource
void invalidRegistrationId(Optional<Integer> registrationId, Optional<Integer> pniRegistrationId, int statusCode) throws InterruptedException, JsonProcessingException { void invalidRegistrationId(Optional<Integer> registrationId, Optional<Integer> pniRegistrationId, int statusCode) throws InterruptedException, JsonProcessingException {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(registrationServiceClient.getSession(any(), any())) when(registrationServiceClient.getSession(any(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS)))); SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class)); .thenReturn(account);
final String recoveryPassword = encodeRecoveryPassword(new byte[0]); final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId);
final Map<String, Object> accountAttrs = new HashMap<>();
accountAttrs.put("recoveryPassword", recoveryPassword);
registrationId.ifPresent(id -> accountAttrs.put("registrationId", id));
pniRegistrationId.ifPresent(id -> accountAttrs.put("pniRegistrationId", id));
final String json = SystemMapper.jsonMapper().writeValueAsString(Map.of(
"sessionId", encodeSessionId("sessionId"),
"recoveryPassword", recoveryPassword,
"accountAttributes", accountAttrs,
"skipDeviceTransfer", true
));
try (Response response = request.post(Entity.json(json))) { try (Response response = request.post(Entity.json(json))) {
assertEquals(statusCode, response.getStatus()); assertEquals(statusCode, response.getStatus());
@ -292,8 +290,12 @@ class RegistrationControllerTest {
void recoveryPasswordManagerVerificationTrue() throws InterruptedException { void recoveryPasswordManagerVerificationTrue() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any())) when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true)); .thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class)); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
@ -340,15 +342,19 @@ class RegistrationControllerTest {
expectedStatus = 409; expectedStatus = 409;
} else if (error != null) { } else if (error != null) {
final Exception e = switch (error) { final Exception e = switch (error) {
case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null, true); case RATE_LIMITED -> new RateLimitExceededException(null, true);
}; };
doThrow(e) doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any()); .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any());
expectedStatus = error.getExpectedStatus(); expectedStatus = error.getExpectedStatus();
} else { } else {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class)); .thenReturn(createdAccount);
expectedStatus = 200; expectedStatus = 200;
} }
@ -396,13 +402,17 @@ class RegistrationControllerTest {
maybeAccount = Optional.empty(); maybeAccount = Optional.empty();
} }
when(accountsManager.getByE164(any())).thenReturn(maybeAccount); when(accountsManager.getByE164(any())).thenReturn(maybeAccount);
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(mock(Account.class));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) { try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer, 1, Optional.of(2))))) {
assertEquals(expectedStatus, response.getStatus()); assertEquals(expectedStatus, response.getStatus());
} }
} }
@ -415,8 +425,12 @@ class RegistrationControllerTest {
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS)))); SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class)); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
@ -447,22 +461,22 @@ class RegistrationControllerTest {
} }
static Stream<Arguments> atomicAccountCreationConflictingChannel() { static Stream<Arguments> atomicAccountCreationConflictingChannel() {
final Optional<IdentityKey> aciIdentityKey; final IdentityKey aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final IdentityKey pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
} }
final AccountAttributes fetchesMessagesAccountAttributes = final AccountAttributes fetchesMessagesAccountAttributes =
@ -477,7 +491,6 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
fetchesMessagesAccountAttributes, fetchesMessagesAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -492,7 +505,6 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
fetchesMessagesAccountAttributes, fetchesMessagesAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -507,7 +519,6 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -539,22 +550,22 @@ class RegistrationControllerTest {
} }
static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() { static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
final Optional<IdentityKey> aciIdentityKey; final IdentityKey aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final IdentityKey pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
} }
final AccountAttributes accountAttributes = final AccountAttributes accountAttributes =
@ -566,11 +577,10 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
Optional.empty(), null,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey,
Optional.empty(), Optional.empty(),
@ -581,10 +591,9 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
Optional.empty(), null,
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey,
@ -596,13 +605,12 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
Optional.empty(), null,
Optional.empty(), Optional.empty(),
Optional.empty())), Optional.empty())),
@ -611,12 +619,11 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
Optional.empty(), null,
pniPqLastResortPreKey, pniPqLastResortPreKey,
Optional.empty(), Optional.empty(),
Optional.empty())), Optional.empty())),
@ -626,8 +633,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false, null,
Optional.empty(),
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
@ -641,9 +647,8 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
Optional.empty(), null,
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
@ -689,13 +694,6 @@ class RegistrationControllerTest {
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account); when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
when(accountsManager.update(eq(account), any())).thenAnswer(invocation -> {
final Consumer<Account> accountUpdater = invocation.getArgument(1);
accountUpdater.accept(account);
return invocation.getArgument(0);
});
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -730,60 +728,23 @@ class RegistrationControllerTest {
() -> verify(device, never()).setGcmId(any())); () -> verify(device, never()).setGcmId(any()));
} }
@ParameterizedTest
@ValueSource(booleans = {false, true})
void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
RegistrationRequest reg = new RegistrationRequest("session-id",
new byte[0],
new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false)),
true,
requireAtomic,
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
try (final Response response = request.post(Entity.json(reg))) {
int expected = requireAtomic ? 422 : 200;
assertEquals(expected, response.getStatus());
}
}
private static Stream<Arguments> atomicAccountCreationSuccess() { private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey; final IdentityKey aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final IdentityKey pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey; final KEMSignedPreKey pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
} }
final AccountAttributes fetchesMessagesAccountAttributes = final AccountAttributes fetchesMessagesAccountAttributes =
@ -796,137 +757,154 @@ class RegistrationControllerTest {
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token"; final String gcmToken = "gcm-token";
return Stream.of(false, true) return Stream.of(
// try with and without strict atomic checking // Fetches messages; no push tokens
.flatMap(requireAtomic -> Arguments.of(new RegistrationRequest("session-id",
Stream.of( new byte[0],
// Fetches messages; no push tokens fetchesMessagesAccountAttributes,
Arguments.of(new RegistrationRequest("session-id", true,
new byte[0], aciIdentityKey,
fetchesMessagesAccountAttributes, pniIdentityKey,
true, aciSignedPreKey,
requireAtomic, pniSignedPreKey,
aciIdentityKey, aciPqLastResortPreKey,
pniIdentityKey, pniPqLastResortPreKey,
aciSignedPreKey, Optional.empty(),
pniSignedPreKey, Optional.empty()),
aciPqLastResortPreKey, aciIdentityKey,
pniPqLastResortPreKey, pniIdentityKey,
Optional.empty(), aciSignedPreKey,
Optional.empty()), pniSignedPreKey,
aciIdentityKey.get(), aciPqLastResortPreKey,
pniIdentityKey.get(), pniPqLastResortPreKey,
aciSignedPreKey.get(), Optional.empty(),
pniSignedPreKey.get(), Optional.empty(),
aciPqLastResortPreKey.get(), Optional.empty()),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),
// Has APNs tokens // Has APNs tokens
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
requireAtomic, aciIdentityKey,
aciIdentityKey, pniIdentityKey,
pniIdentityKey, aciSignedPreKey,
aciSignedPreKey, pniSignedPreKey,
pniSignedPreKey, aciPqLastResortPreKey,
aciPqLastResortPreKey, pniPqLastResortPreKey,
pniPqLastResortPreKey, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty()),
Optional.empty()), aciIdentityKey,
aciIdentityKey.get(), pniIdentityKey,
pniIdentityKey.get(), aciSignedPreKey,
aciSignedPreKey.get(), pniSignedPreKey,
pniSignedPreKey.get(), aciPqLastResortPreKey,
aciPqLastResortPreKey.get(), pniPqLastResortPreKey,
pniPqLastResortPreKey.get(), Optional.of(apnsToken),
Optional.of(apnsToken), Optional.of(apnsVoipToken),
Optional.of(apnsVoipToken), Optional.empty()),
Optional.empty()),
// requires the request to be atomic // requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
requireAtomic, aciIdentityKey,
aciIdentityKey, pniIdentityKey,
pniIdentityKey, aciSignedPreKey,
aciSignedPreKey, pniSignedPreKey,
pniSignedPreKey, aciPqLastResortPreKey,
aciPqLastResortPreKey, pniPqLastResortPreKey,
pniPqLastResortPreKey, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty()),
Optional.empty()), aciIdentityKey,
aciIdentityKey.get(), pniIdentityKey,
pniIdentityKey.get(), aciSignedPreKey,
aciSignedPreKey.get(), pniSignedPreKey,
pniSignedPreKey.get(), aciPqLastResortPreKey,
aciPqLastResortPreKey.get(), pniPqLastResortPreKey,
pniPqLastResortPreKey.get(), Optional.of(apnsToken),
Optional.of(apnsToken), Optional.of(apnsVoipToken),
Optional.of(apnsVoipToken), Optional.empty()),
Optional.empty()),
// Fetches messages; no push tokens // Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
requireAtomic, aciIdentityKey,
aciIdentityKey, pniIdentityKey,
pniIdentityKey, aciSignedPreKey,
aciSignedPreKey, pniSignedPreKey,
pniSignedPreKey, aciPqLastResortPreKey,
aciPqLastResortPreKey, pniPqLastResortPreKey,
pniPqLastResortPreKey, Optional.empty(),
Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken))),
Optional.of(new GcmRegistrationId(gcmToken))), aciIdentityKey,
aciIdentityKey.get(), pniIdentityKey,
pniIdentityKey.get(), aciSignedPreKey,
aciSignedPreKey.get(), pniSignedPreKey,
pniSignedPreKey.get(), aciPqLastResortPreKey,
aciPqLastResortPreKey.get(), pniPqLastResortPreKey,
pniPqLastResortPreKey.get(), Optional.empty(),
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.of(gcmToken)));
Optional.of(gcmToken))));
} }
/** /**
* Valid request JSON with the give session ID and skipDeviceTransfer * Valid request JSON with the give session ID and skipDeviceTransfer
*/ */
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer) { private static String requestJson(final String sessionId,
final String rp = encodeRecoveryPassword(recoveryPassword); final byte[] recoveryPassword,
return String.format(""" final boolean skipDeviceTransfer,
{ final int registrationId,
"sessionId": "%s", @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Integer> pniRegistrationId) {
"recoveryPassword": "%s",
"accountAttributes": { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
"recoveryPassword": "%s", final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
"registrationId": 1
}, final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
"skipDeviceTransfer": %s final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
}
""", encodeSessionId(sessionId), rp, rp, skipDeviceTransfer); final AccountAttributes accountAttributes = new AccountAttributes(true, registrationId, "name", "reglock", true,
new Device.DeviceCapabilities(true, true, true, true));
pniRegistrationId.ifPresent(accountAttributes::setPhoneNumberIdentityRegistrationId);
final RegistrationRequest request = new RegistrationRequest(
Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)),
recoveryPassword,
accountAttributes,
skipDeviceTransfer,
aciIdentityKey,
pniIdentityKey,
new DeviceActivationRequest(
KeysHelper.signedECPreKey(1, aciIdentityKeyPair),
KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair),
KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair),
Optional.empty(),
Optional.empty()));
try {
return SystemMapper.jsonMapper().writerWithDefaultPrettyPrinter().writeValueAsString(request);
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
} }
/** /**
* Valid request JSON with the given session ID * Valid request JSON with the given session ID
*/ */
private static String requestJson(final String sessionId) { private static String requestJson(final String sessionId) {
return requestJson(sessionId, new byte[0], false); return requestJson(sessionId, new byte[0], false, 1, Optional.of(2));
} }
/** /**
* Valid request JSON with the given Recovery Password * Valid request JSON with the given Recovery Password
*/ */
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) { private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) {
return requestJson("", recoveryPassword, false); return requestJson("", recoveryPassword, false, 1, Optional.of(2));
} }
/** /**
@ -953,12 +931,4 @@ class RegistrationControllerTest {
} }
"""; """;
} }
private static String encodeSessionId(final String sessionId) {
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
}
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
return Base64.getEncoder().encodeToString(recoveryPassword);
}
} }