Normalize identity types throughout `KeysController`

This commit is contained in:
Jon Chambers 2023-11-30 10:36:39 -05:00 committed by Jon Chambers
parent e2037dea6c
commit 4cca7aa4bd
4 changed files with 32 additions and 39 deletions

View File

@ -24,11 +24,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException; import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
@ -97,13 +97,13 @@ public class KeysController {
@ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedAccount auth, public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture = final CompletableFuture<Integer> ecCountFuture =
keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); keys.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> pqCountFuture = final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); keys.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
} }
@ -129,30 +129,25 @@ public class KeysController {
allowableValues={"aci", "pni"}, allowableValues={"aci", "pni"},
defaultValue="aci", defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity") description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType, @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
boolean updateAccount = false; boolean updateAccount = false;
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType); if (preKeys.signedPreKey() != null && !preKeys.signedPreKey().equals(device.getSignedPreKey(identityType))) {
if (preKeys.signedPreKey() != null &&
!preKeys.signedPreKey().equals(usePhoneNumberIdentity ? device.getSignedPreKey(IdentityType.PNI)
: device.getSignedPreKey(IdentityType.ACI))) {
updateAccount = true; updateAccount = true;
} }
final IdentityKey oldIdentityKey = final IdentityKey oldIdentityKey = account.getIdentityKey(identityType);
usePhoneNumberIdentity ? account.getIdentityKey(IdentityType.PNI) : account.getIdentityKey(IdentityType.ACI);
if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) { if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) {
updateAccount = true; updateAccount = true;
final boolean hasIdentityKey = oldIdentityKey != null; final boolean hasIdentityKey = oldIdentityKey != null;
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)) final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))
.and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey)) .and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey))
.and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci"); .and(IDENTITY_TYPE_TAG_NAME, identityType.name());
if (!device.isPrimary()) { if (!device.isPrimary()) {
Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment(); Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment();
@ -163,7 +158,7 @@ public class KeysController {
if (hasIdentityKey) { if (hasIdentityKey) {
logger.warn("Existing {} identity key changed; account age is {} days", logger.warn("Existing {} identity key changed; account age is {} days",
identityType.orElse("aci"), identityType,
Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays()); Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays());
} }
} }
@ -172,23 +167,21 @@ public class KeysController {
account = accounts.update(account, a -> { account = accounts.update(account, a -> {
if (preKeys.signedPreKey() != null) { if (preKeys.signedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> { a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) { switch (identityType) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey()); case ACI -> d.setSignedPreKey(preKeys.signedPreKey());
} else { case PNI -> d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey());
d.setSignedPreKey(preKeys.signedPreKey());
} }
}); });
} }
if (usePhoneNumberIdentity) { switch (identityType) {
a.setPhoneNumberIdentityKey(preKeys.identityKey()); case ACI -> a.setIdentityKey(preKeys.identityKey());
} else { case PNI -> a.setPhoneNumberIdentityKey(preKeys.identityKey());
a.setIdentityKey(preKeys.identityKey());
} }
}); });
} }
return keys.store(getIdentifier(account, identityType), device.getId(), return keys.store(account.getIdentifier(identityType), device.getId(),
preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey()) preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey())
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }
@ -302,33 +295,22 @@ public class KeysController {
@ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "422", description = "Invalid request format.")
public CompletableFuture<Response> setSignedKey(@Auth final AuthenticatedAccount auth, public CompletableFuture<Response> setSignedKey(@Auth final AuthenticatedAccount auth,
@Valid final ECSignedPreKey signedPreKey, @Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
Device device = auth.getAuthenticatedDevice(); Device device = auth.getAuthenticatedDevice();
accounts.updateDevice(auth.getAccount(), device.getId(), d -> { accounts.updateDevice(auth.getAccount(), device.getId(), d -> {
if (usePhoneNumberIdentity(identityType)) { switch (identityType) {
d.setPhoneNumberIdentitySignedPreKey(signedPreKey); case ACI -> d.setSignedPreKey(signedPreKey);
} else { case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
d.setSignedPreKey(signedPreKey);
} }
}); });
return keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType), return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType),
Map.of(device.getId(), signedPreKey)) Map.of(device.getId(), signedPreKey))
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }
private static boolean usePhoneNumberIdentity(final Optional<String> identityType) {
return "pni".equals(identityType.map(String::toLowerCase).orElse("aci"));
}
private static UUID getIdentifier(final Account account, final Optional<String> identityType) {
return usePhoneNumberIdentity(identityType) ?
account.getPhoneNumberIdentifier() :
account.getUuid();
}
private List<Device> parseDeviceId(String deviceId, Account account) { private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) { if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList(); return account.getDevices().stream().filter(Device::isEnabled).toList();

View File

@ -217,6 +217,8 @@ class KeysControllerTest {
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI); when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI);
when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID);
when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI);
when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice)); when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2)); when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3)); when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3));

View File

@ -119,6 +119,7 @@ public class AccountsHelper {
switch (stubbing.getInvocation().getMethod().getName()) { switch (stubbing.getInvocation().getMethod().getName()) {
case "getUuid" -> when(updatedAccount.getUuid()).thenAnswer(stubbing); case "getUuid" -> when(updatedAccount.getUuid()).thenAnswer(stubbing);
case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing); case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing);
case "getIdentifier" -> when(updatedAccount.getIdentifier(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "isIdentifiedBy" -> when(updatedAccount.isIdentifiedBy(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); case "isIdentifiedBy" -> when(updatedAccount.isIdentifiedBy(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing); case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing);
case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing); case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing);

View File

@ -161,16 +161,24 @@ public class AuthHelper {
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER); when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID); when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI); when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI);
when(VALID_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID);
when(VALID_ACCOUNT.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI);
when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO); when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO);
when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO); when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO); when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_TWO);
when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO);
when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER); when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER);
when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID); when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID);
when(DISABLED_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(DISABLED_UUID);
when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER);
when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID);
when(UNDISCOVERABLE_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(UNDISCOVERABLE_UUID);
when(VALID_ACCOUNT_3.getNumber()).thenReturn(VALID_NUMBER_3); when(VALID_ACCOUNT_3.getNumber()).thenReturn(VALID_NUMBER_3);
when(VALID_ACCOUNT_3.getUuid()).thenReturn(VALID_UUID_3); when(VALID_ACCOUNT_3.getUuid()).thenReturn(VALID_UUID_3);
when(VALID_ACCOUNT_3.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_3); when(VALID_ACCOUNT_3.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_3);
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_3);
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_3);
when(VALID_ACCOUNT.isEnabled()).thenReturn(true); when(VALID_ACCOUNT.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true); when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true);