Normalize identity types throughout `KeysController`

This commit is contained in:
Jon Chambers 2023-11-30 10:36:39 -05:00 committed by Jon Chambers
parent e2037dea6c
commit 4cca7aa4bd
4 changed files with 32 additions and 39 deletions

View File

@ -24,11 +24,11 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
@ -97,13 +97,13 @@ public class KeysController {
@ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture =
keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
keys.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
keys.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
}
@ -129,30 +129,25 @@ public class KeysController {
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
boolean updateAccount = false;
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType);
if (preKeys.signedPreKey() != null &&
!preKeys.signedPreKey().equals(usePhoneNumberIdentity ? device.getSignedPreKey(IdentityType.PNI)
: device.getSignedPreKey(IdentityType.ACI))) {
if (preKeys.signedPreKey() != null && !preKeys.signedPreKey().equals(device.getSignedPreKey(identityType))) {
updateAccount = true;
}
final IdentityKey oldIdentityKey =
usePhoneNumberIdentity ? account.getIdentityKey(IdentityType.PNI) : account.getIdentityKey(IdentityType.ACI);
final IdentityKey oldIdentityKey = account.getIdentityKey(identityType);
if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) {
updateAccount = true;
final boolean hasIdentityKey = oldIdentityKey != null;
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))
.and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey))
.and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci");
.and(IDENTITY_TYPE_TAG_NAME, identityType.name());
if (!device.isPrimary()) {
Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment();
@ -163,7 +158,7 @@ public class KeysController {
if (hasIdentityKey) {
logger.warn("Existing {} identity key changed; account age is {} days",
identityType.orElse("aci"),
identityType,
Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays());
}
}
@ -172,23 +167,21 @@ public class KeysController {
account = accounts.update(account, a -> {
if (preKeys.signedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey());
} else {
d.setSignedPreKey(preKeys.signedPreKey());
switch (identityType) {
case ACI -> d.setSignedPreKey(preKeys.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey());
}
});
}
if (usePhoneNumberIdentity) {
a.setPhoneNumberIdentityKey(preKeys.identityKey());
} else {
a.setIdentityKey(preKeys.identityKey());
switch (identityType) {
case ACI -> a.setIdentityKey(preKeys.identityKey());
case PNI -> a.setPhoneNumberIdentityKey(preKeys.identityKey());
}
});
}
return keys.store(getIdentifier(account, identityType), device.getId(),
return keys.store(account.getIdentifier(identityType), device.getId(),
preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@ -302,33 +295,22 @@ public class KeysController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public CompletableFuture<Response> setSignedKey(@Auth final AuthenticatedAccount auth,
@Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") final Optional<String> identityType) {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
Device device = auth.getAuthenticatedDevice();
accounts.updateDevice(auth.getAccount(), device.getId(), d -> {
if (usePhoneNumberIdentity(identityType)) {
d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
} else {
d.setSignedPreKey(signedPreKey);
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
});
return keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType),
return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType),
Map.of(device.getId(), signedPreKey))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
private static boolean usePhoneNumberIdentity(final Optional<String> identityType) {
return "pni".equals(identityType.map(String::toLowerCase).orElse("aci"));
}
private static UUID getIdentifier(final Account account, final Optional<String> identityType) {
return usePhoneNumberIdentity(identityType) ?
account.getPhoneNumberIdentifier() :
account.getUuid();
}
private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();

View File

@ -217,6 +217,8 @@ class KeysControllerTest {
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI);
when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID);
when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI);
when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3));

View File

@ -119,6 +119,7 @@ public class AccountsHelper {
switch (stubbing.getInvocation().getMethod().getName()) {
case "getUuid" -> when(updatedAccount.getUuid()).thenAnswer(stubbing);
case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing);
case "getIdentifier" -> when(updatedAccount.getIdentifier(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "isIdentifiedBy" -> when(updatedAccount.isIdentifiedBy(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing);
case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing);

View File

@ -161,16 +161,24 @@ public class AuthHelper {
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI);
when(VALID_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID);
when(VALID_ACCOUNT.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI);
when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO);
when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER);
when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID);
when(DISABLED_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(DISABLED_UUID);
when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER);
when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID);
when(UNDISCOVERABLE_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(UNDISCOVERABLE_UUID);
when(VALID_ACCOUNT_3.getNumber()).thenReturn(VALID_NUMBER_3);
when(VALID_ACCOUNT_3.getUuid()).thenReturn(VALID_UUID_3);
when(VALID_ACCOUNT_3.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_3);
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_3);
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_3);
when(VALID_ACCOUNT.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true);