Add an optional parameter to require atomic account creation

By default, if a registration request has no optional fields for atomic
account creation set, the request will proceed non-atomically. If a
client sets the `atomic` field, now such a request would be rejected.
This commit is contained in:
ravi-signal 2023-07-05 11:24:11 -05:00 committed by GitHub
parent b593d49399
commit fedeef4da5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 160 additions and 74 deletions

View File

@ -77,7 +77,7 @@ public final class Operations {
// register account // register account
final RegistrationRequest registrationRequest = new RegistrationRequest( final RegistrationRequest registrationRequest = new RegistrationRequest(
null, registrationPassword, accountAttributes, true, null, registrationPassword, accountAttributes, true, false,
Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest) final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest)
@ -113,6 +113,7 @@ public final class Operations {
registrationPassword, registrationPassword,
accountAttributes, accountAttributes,
true, true,
true,
Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())), Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())),
Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())), Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())),
Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)), Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)),

View File

@ -50,6 +50,13 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
""") """)
boolean skipDeviceTransfer, boolean skipDeviceTransfer,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
If true, indicates that this is a request for "atomic" registration. If any properties
needed for atomic account creation are not present, the request will fail. If false,
atomic account creation can still occur, but only if all required fields are present.
""")
boolean requireAtomic,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ACI-associated identity key for the account, encoded as a base64 string. If The ACI-associated identity key for the account, encoded as a base64 string. If
provided, an account will be created "atomically," and all other properties needed for provided, an account will be created "atomically," and all other properties needed for
@ -78,6 +85,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@JsonProperty("recoveryPassword") byte[] recoveryPassword, @JsonProperty("recoveryPassword") byte[] recoveryPassword,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("requireAtomic") boolean requireAtomic,
@JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey, @JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey,
@JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey, @JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@ -90,7 +98,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
// This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in // This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in
// records, and this is a workaround. Please see // records, and this is a workaround. Please see
// https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context. // https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context.
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, aciIdentityKey, pniIdentityKey, this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, requireAtomic, aciIdentityKey, pniIdentityKey,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken));
} }
@ -122,7 +130,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
&& deviceActivationRequest().aciPqLastResortPreKey().isEmpty() && deviceActivationRequest().aciPqLastResortPreKey().isEmpty()
&& deviceActivationRequest().pniPqLastResortPreKey().isEmpty(); && deviceActivationRequest().pniPqLastResortPreKey().isEmpty();
return supportsAtomicAccountCreation() || hasNoAtomicAccountCreationParameters; return supportsAtomicAccountCreation() || (!requireAtomic() && hasNoAtomicAccountCreationParameters);
} }
public boolean supportsAtomicAccountCreation() { public boolean supportsAtomicAccountCreation() {

View File

@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -117,10 +118,10 @@ class RegistrationControllerTest {
@Test @Test
public void testRegistrationRequest() throws Exception { public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
} }
@Test @Test
@ -447,6 +448,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
fetchesMessagesAccountAttributes, fetchesMessagesAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -461,6 +463,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
fetchesMessagesAccountAttributes, fetchesMessagesAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -475,6 +478,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -533,6 +537,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -547,6 +552,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
Optional.empty(), Optional.empty(),
@ -561,6 +567,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -575,6 +582,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -589,6 +597,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
Optional.empty(), Optional.empty(),
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
@ -603,6 +612,7 @@ class RegistrationControllerTest {
new byte[0], new byte[0],
accountAttributes, accountAttributes,
true, true,
false,
aciIdentityKey, aciIdentityKey,
Optional.empty(), Optional.empty(),
aciSignedPreKey, aciSignedPreKey,
@ -686,6 +696,43 @@ class RegistrationControllerTest {
() -> verify(device, never()).setGcmId(any())); () -> verify(device, never()).setGcmId(any()));
} }
@ParameterizedTest
@ValueSource(booleans = {false, true})
void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
RegistrationRequest reg = new RegistrationRequest("session-id",
new byte[0],
new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities()),
true,
requireAtomic,
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
try (final Response response = request.post(Entity.json(reg))) {
int expected = requireAtomic ? 422 : 200;
assertEquals(expected, response.getStatus());
}
}
private static Stream<Arguments> atomicAccountCreationSuccess() { private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
@ -715,75 +762,105 @@ class RegistrationControllerTest {
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token"; final String gcmToken = "gcm-token";
return Stream.of( return Stream.of(false, true)
// Fetches messages; no push tokens // try with and without strict atomic checking
Arguments.of(new RegistrationRequest("session-id", .flatMap(requireAtomic ->
new byte[0], Stream.of(
fetchesMessagesAccountAttributes, // Fetches messages; no push tokens
true, Arguments.of(new RegistrationRequest("session-id",
aciIdentityKey, new byte[0],
pniIdentityKey, fetchesMessagesAccountAttributes,
aciSignedPreKey, true,
pniSignedPreKey, requireAtomic,
aciPqLastResortPreKey, aciIdentityKey,
pniPqLastResortPreKey, pniIdentityKey,
Optional.empty(), aciSignedPreKey,
Optional.empty()), pniSignedPreKey,
aciIdentityKey.get(), aciPqLastResortPreKey,
pniIdentityKey.get(), pniPqLastResortPreKey,
aciSignedPreKey.get(), Optional.empty(),
pniSignedPreKey.get(), Optional.empty()),
aciPqLastResortPreKey.get(), aciIdentityKey.get(),
pniPqLastResortPreKey.get(), pniIdentityKey.get(),
Optional.empty(), aciSignedPreKey.get(),
Optional.empty(), pniSignedPreKey.get(),
Optional.empty()), aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),
// Has APNs tokens // Has APNs tokens
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
aciIdentityKey, requireAtomic,
pniIdentityKey, aciIdentityKey,
aciSignedPreKey, pniIdentityKey,
pniSignedPreKey, aciSignedPreKey,
aciPqLastResortPreKey, pniSignedPreKey,
pniPqLastResortPreKey, aciPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), pniPqLastResortPreKey,
Optional.empty()), Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
aciIdentityKey.get(), Optional.empty()),
pniIdentityKey.get(), aciIdentityKey.get(),
aciSignedPreKey.get(), pniIdentityKey.get(),
pniSignedPreKey.get(), aciSignedPreKey.get(),
aciPqLastResortPreKey.get(), pniSignedPreKey.get(),
pniPqLastResortPreKey.get(), aciPqLastResortPreKey.get(),
Optional.of(apnsToken), pniPqLastResortPreKey.get(),
Optional.of(apnsVoipToken), Optional.of(apnsToken),
Optional.empty()), Optional.of(apnsVoipToken),
Optional.empty()),
// Fetches messages; no push tokens // requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
true, true,
aciIdentityKey, requireAtomic,
pniIdentityKey, aciIdentityKey,
aciSignedPreKey, pniIdentityKey,
pniSignedPreKey, aciSignedPreKey,
aciPqLastResortPreKey, pniSignedPreKey,
pniPqLastResortPreKey, aciPqLastResortPreKey,
Optional.empty(), pniPqLastResortPreKey,
Optional.of(new GcmRegistrationId(gcmToken))), Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
aciIdentityKey.get(), Optional.empty()),
pniIdentityKey.get(), aciIdentityKey.get(),
aciSignedPreKey.get(), pniIdentityKey.get(),
pniSignedPreKey.get(), aciSignedPreKey.get(),
aciPqLastResortPreKey.get(), pniSignedPreKey.get(),
pniPqLastResortPreKey.get(), aciPqLastResortPreKey.get(),
Optional.empty(), pniPqLastResortPreKey.get(),
Optional.empty(), Optional.of(apnsToken),
Optional.of(gcmToken))); Optional.of(apnsVoipToken),
Optional.empty()),
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken))));
} }
/** /**