Add an optional parameter to require atomic account creation

By default, if a registration request has no optional fields for atomic
account creation set, the request will proceed non-atomically. If a
client sets the `atomic` field, now such a request would be rejected.
This commit is contained in:
ravi-signal 2023-07-05 11:24:11 -05:00 committed by GitHub
parent b593d49399
commit fedeef4da5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 160 additions and 74 deletions

View File

@ -77,7 +77,7 @@ public final class Operations {
// register account
final RegistrationRequest registrationRequest = new RegistrationRequest(
null, registrationPassword, accountAttributes, true,
null, registrationPassword, accountAttributes, true, false,
Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest)
@ -113,6 +113,7 @@ public final class Operations {
registrationPassword,
accountAttributes,
true,
true,
Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())),
Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())),
Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)),

View File

@ -50,6 +50,13 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
""")
boolean skipDeviceTransfer,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
If true, indicates that this is a request for "atomic" registration. If any properties
needed for atomic account creation are not present, the request will fail. If false,
atomic account creation can still occur, but only if all required fields are present.
""")
boolean requireAtomic,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ACI-associated identity key for the account, encoded as a base64 string. If
provided, an account will be created "atomically," and all other properties needed for
@ -78,6 +85,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@JsonProperty("recoveryPassword") byte[] recoveryPassword,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("requireAtomic") boolean requireAtomic,
@JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey,
@JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@ -90,7 +98,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
// This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in
// records, and this is a workaround. Please see
// https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context.
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, aciIdentityKey, pniIdentityKey,
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, requireAtomic, aciIdentityKey, pniIdentityKey,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken));
}
@ -122,7 +130,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
&& deviceActivationRequest().aciPqLastResortPreKey().isEmpty()
&& deviceActivationRequest().pniPqLastResortPreKey().isEmpty();
return supportsAtomicAccountCreation() || hasNoAtomicAccountCreationParameters;
return supportsAtomicAccountCreation() || (!requireAtomic() && hasNoAtomicAccountCreationParameters);
}
public boolean supportsAtomicAccountCreation() {

View File

@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -117,10 +118,10 @@ class RegistrationControllerTest {
@Test
public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
}
@Test
@ -447,6 +448,7 @@ class RegistrationControllerTest {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -461,6 +463,7 @@ class RegistrationControllerTest {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -475,6 +478,7 @@ class RegistrationControllerTest {
new byte[0],
pushAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -533,6 +537,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -547,6 +552,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
Optional.empty(),
@ -561,6 +567,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -575,6 +582,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@ -589,6 +597,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
Optional.empty(),
pniIdentityKey,
aciSignedPreKey,
@ -603,6 +612,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
Optional.empty(),
aciSignedPreKey,
@ -686,6 +696,43 @@ class RegistrationControllerTest {
() -> verify(device, never()).setGcmId(any()));
}
@ParameterizedTest
@ValueSource(booleans = {false, true})
void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
RegistrationRequest reg = new RegistrationRequest("session-id",
new byte[0],
new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities()),
true,
requireAtomic,
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
try (final Response response = request.post(Entity.json(reg))) {
int expected = requireAtomic ? 422 : 200;
assertEquals(expected, response.getStatus());
}
}
private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
@ -715,75 +762,105 @@ class RegistrationControllerTest {
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";
return Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),
return Stream.of(false, true)
// try with and without strict atomic checking
.flatMap(requireAtomic ->
Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken)));
// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken))));
}
/**