diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 6c65e139f..47f194370 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -24,11 +24,11 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; +import javax.ws.rs.DefaultValue; import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; @@ -97,13 +97,13 @@ public class KeysController { @ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") public CompletableFuture getStatus(@Auth final AuthenticatedAccount auth, - @QueryParam("identity") final Optional identityType) { + @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { final CompletableFuture ecCountFuture = - keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + keys.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); final CompletableFuture pqCountFuture = - keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + keys.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); } @@ -129,30 +129,25 @@ public class KeysController { allowableValues={"aci", "pni"}, defaultValue="aci", description="whether this operation applies to the account (aci) or phone-number (pni) identity") - @QueryParam("identity") final Optional identityType, + @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { Account account = disabledPermittedAuth.getAccount(); Device device = disabledPermittedAuth.getAuthenticatedDevice(); boolean updateAccount = false; - final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType); - - if (preKeys.signedPreKey() != null && - !preKeys.signedPreKey().equals(usePhoneNumberIdentity ? device.getSignedPreKey(IdentityType.PNI) - : device.getSignedPreKey(IdentityType.ACI))) { + if (preKeys.signedPreKey() != null && !preKeys.signedPreKey().equals(device.getSignedPreKey(identityType))) { updateAccount = true; } - final IdentityKey oldIdentityKey = - usePhoneNumberIdentity ? account.getIdentityKey(IdentityType.PNI) : account.getIdentityKey(IdentityType.ACI); + final IdentityKey oldIdentityKey = account.getIdentityKey(identityType); if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) { updateAccount = true; final boolean hasIdentityKey = oldIdentityKey != null; final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)) .and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey)) - .and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci"); + .and(IDENTITY_TYPE_TAG_NAME, identityType.name()); if (!device.isPrimary()) { Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment(); @@ -163,7 +158,7 @@ public class KeysController { if (hasIdentityKey) { logger.warn("Existing {} identity key changed; account age is {} days", - identityType.orElse("aci"), + identityType, Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays()); } } @@ -172,23 +167,21 @@ public class KeysController { account = accounts.update(account, a -> { if (preKeys.signedPreKey() != null) { a.getDevice(device.getId()).ifPresent(d -> { - if (usePhoneNumberIdentity) { - d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey()); - } else { - d.setSignedPreKey(preKeys.signedPreKey()); + switch (identityType) { + case ACI -> d.setSignedPreKey(preKeys.signedPreKey()); + case PNI -> d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey()); } }); } - if (usePhoneNumberIdentity) { - a.setPhoneNumberIdentityKey(preKeys.identityKey()); - } else { - a.setIdentityKey(preKeys.identityKey()); + switch (identityType) { + case ACI -> a.setIdentityKey(preKeys.identityKey()); + case PNI -> a.setPhoneNumberIdentityKey(preKeys.identityKey()); } }); } - return keys.store(getIdentifier(account, identityType), device.getId(), + return keys.store(account.getIdentifier(identityType), device.getId(), preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey()) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } @@ -302,33 +295,22 @@ public class KeysController { @ApiResponse(responseCode = "422", description = "Invalid request format.") public CompletableFuture setSignedKey(@Auth final AuthenticatedAccount auth, @Valid final ECSignedPreKey signedPreKey, - @QueryParam("identity") final Optional identityType) { + @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { Device device = auth.getAuthenticatedDevice(); accounts.updateDevice(auth.getAccount(), device.getId(), d -> { - if (usePhoneNumberIdentity(identityType)) { - d.setPhoneNumberIdentitySignedPreKey(signedPreKey); - } else { - d.setSignedPreKey(signedPreKey); + switch (identityType) { + case ACI -> d.setSignedPreKey(signedPreKey); + case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey); } }); - return keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType), + return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType), Map.of(device.getId(), signedPreKey)) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } - private static boolean usePhoneNumberIdentity(final Optional identityType) { - return "pni".equals(identityType.map(String::toLowerCase).orElse("aci")); - } - - private static UUID getIdentifier(final Account account, final Optional identityType) { - return usePhoneNumberIdentity(identityType) ? - account.getPhoneNumberIdentifier() : - account.getUuid(); - } - private List parseDeviceId(String deviceId, Account account) { if (deviceId.equals("*")) { return account.getDevices().stream().filter(Device::isEnabled).toList(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index cb5c30c40..699ab3a34 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -217,6 +217,8 @@ class KeysControllerTest { when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI); + when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID); + when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI); when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice)); when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2)); when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 10454d612..be10b451d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -119,6 +119,7 @@ public class AccountsHelper { switch (stubbing.getInvocation().getMethod().getName()) { case "getUuid" -> when(updatedAccount.getUuid()).thenAnswer(stubbing); case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing); + case "getIdentifier" -> when(updatedAccount.getIdentifier(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); case "isIdentifiedBy" -> when(updatedAccount.isIdentifiedBy(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing); case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 33920e098..20179dee0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -161,16 +161,24 @@ public class AuthHelper { when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER); when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID); when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI); + when(VALID_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID); + when(VALID_ACCOUNT.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI); when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO); when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO); when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO); + when(VALID_ACCOUNT_TWO.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_TWO); + when(VALID_ACCOUNT_TWO.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_TWO); when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER); when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID); + when(DISABLED_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(DISABLED_UUID); when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); + when(UNDISCOVERABLE_ACCOUNT.getIdentifier(IdentityType.ACI)).thenReturn(UNDISCOVERABLE_UUID); when(VALID_ACCOUNT_3.getNumber()).thenReturn(VALID_NUMBER_3); when(VALID_ACCOUNT_3.getUuid()).thenReturn(VALID_UUID_3); when(VALID_ACCOUNT_3.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_3); + when(VALID_ACCOUNT_3.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_3); + when(VALID_ACCOUNT_3.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_3); when(VALID_ACCOUNT.isEnabled()).thenReturn(true); when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true);