diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 2b7127a02..c45884bf9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -387,11 +387,25 @@ public class ProfileController { private void checkFingerprintAndAdd(BatchIdentityCheckRequest.Element element, Collection responseElements, MessageDigest md) { - accountsManager.getByAccountIdentifier(element.aci()).ifPresent(account -> { - if (account.getIdentityKey() == null) return; + + final Optional maybeAccount; + final boolean usePhoneNumberIdentity; + if (element.aci() != null) { + maybeAccount = accountsManager.getByAccountIdentifier(element.aci()); + usePhoneNumberIdentity = false; + } else { + maybeAccount = accountsManager.getByPhoneNumberIdentifier(element.pni()); + usePhoneNumberIdentity = true; + } + + maybeAccount.ifPresent(account -> { + if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) { + return; + } byte[] identityKeyBytes; try { - identityKeyBytes = Base64.getDecoder().decode(account.getIdentityKey()); + identityKeyBytes = Base64.getDecoder().decode(usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() + : account.getIdentityKey()); } catch (IllegalArgumentException ignored) { return; } @@ -400,7 +414,7 @@ public class ProfileController { byte[] fingerprint = Util.truncate(digest, 4); if (!Arrays.equals(fingerprint, element.fingerprint())) { - responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), identityKeyBytes)); + responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.pni(), identityKeyBytes)); } }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java index 6d4ad32e9..12ce35717 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities; import java.util.List; import java.util.UUID; +import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.validation.constraints.Size; @@ -15,11 +16,23 @@ import org.whispersystems.textsecuregcm.util.ExactlySize; public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List elements) { /** + * Exactly one of {@code aci} and {@code pni} must be non-null + * * @param aci account id - * @param fingerprint most significant 4 bytes of SHA-256 of the 33-byte identity key field (32-byte curve25519 - * public key prefixed with 0x05) + * @param pni phone number id + * @param fingerprint most significant 4 bytes of SHA-256 of the 33-byte identity key field (32-byte curve25519 public + * key prefixed with 0x05) */ - public record Element(@NotNull UUID aci, @NotNull @ExactlySize(4) byte[] fingerprint) { + public record Element(@Nullable UUID aci, @Nullable UUID pni, @NotNull @ExactlySize(4) byte[] fingerprint) { + public Element { + if (aci == null && pni == null) { + throw new IllegalArgumentException("aci and pni cannot both be null"); + } + + if (aci != null && pni != null) { + throw new IllegalArgumentException("aci and pni cannot both be non-null"); + } + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java index 00301f46f..27f926f3d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java @@ -7,10 +7,26 @@ package org.whispersystems.textsecuregcm.entities; import java.util.List; import java.util.UUID; +import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; import org.whispersystems.textsecuregcm.util.ExactlySize; public record BatchIdentityCheckResponse(@Valid List elements) { - public record Element(@NotNull UUID aci, @NotNull @ExactlySize(33) byte[] identityKey) {} + + /** + * Exactly one of {@code aci} and {@code pni} must be non-null + */ + public record Element(@Nullable UUID aci, @Nullable UUID pni, @NotNull @ExactlySize(33) byte[] identityKey) { + + public Element { + if (aci == null && pni == null) { + throw new IllegalArgumentException("aci and pni cannot both be null"); + } + + if (aci != null && pni != null) { + throw new IllegalArgumentException("aci and pni cannot both be non-null"); + } + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index cfcab4f3e..0d283a3e4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -120,14 +120,19 @@ class ProfileControllerTest { private static final RateLimiter usernameRateLimiter = mock(RateLimiter.class); private static final S3Client s3client = mock(S3Client.class); - private static final PostPolicyGenerator postPolicyGenerator = new PostPolicyGenerator("us-west-1", "profile-bucket", "accessKey"); + private static final PostPolicyGenerator postPolicyGenerator = new PostPolicyGenerator("us-west-1", "profile-bucket", + "accessKey"); private static final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1"); private static final ServerZkProfileOperations zkProfileOperations = mock(ServerZkProfileOperations.class); private static final byte[] UNIDENTIFIED_ACCESS_KEY = "test-uak".getBytes(StandardCharsets.UTF_8); - + private static final String ACCOUNT_IDENTITY_KEY = "barz"; + private static final String ACCOUNT_PHONE_NUMBER_IDENTITY_KEY = "bazz"; + private static final String ACCOUNT_TWO_IDENTITY_KEY = "bar"; + private static final String ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY = "baz"; @SuppressWarnings("unchecked") - private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + private static final DynamicConfigurationManager dynamicConfigurationManager = mock( + DynamicConfigurationManager.class); private DynamicPaymentsConfiguration dynamicPaymentsConfiguration; private Account profileAccount; @@ -183,8 +188,8 @@ class ProfileControllerTest { profileAccount = mock(Account.class); - when(profileAccount.getIdentityKey()).thenReturn("bar"); - when(profileAccount.getPhoneNumberIdentityKey()).thenReturn("baz"); + when(profileAccount.getIdentityKey()).thenReturn(ACCOUNT_TWO_IDENTITY_KEY); + when(profileAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); when(profileAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID_TWO); when(profileAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI_TWO); when(profileAccount.isEnabled()).thenReturn(true); @@ -199,7 +204,8 @@ class ProfileControllerTest { Account capabilitiesAccount = mock(Account.class); - when(capabilitiesAccount.getIdentityKey()).thenReturn("barz"); + when(capabilitiesAccount.getIdentityKey()).thenReturn(ACCOUNT_IDENTITY_KEY); + when(capabilitiesAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY); when(capabilitiesAccount.isEnabled()).thenReturn(true); when(capabilitiesAccount.isGroupsV2Supported()).thenReturn(true); when(capabilitiesAccount.isGv1MigrationSupported()).thenReturn(true); @@ -242,7 +248,7 @@ class ProfileControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(BaseProfileResponse.class); - assertThat(profile.getIdentityKey()).isEqualTo("bar"); + assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY); assertThat(profile.getBadges()).hasSize(1).element(0).has(new Condition<>( badge -> "Test Badge".equals(badge.getName()), "has badge with expected name")); @@ -272,7 +278,7 @@ class ProfileControllerTest { .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .get(BaseProfileResponse.class); - assertThat(profile.getIdentityKey()).isEqualTo("bar"); + assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY); assertThat(profile.getBadges()).hasSize(1).element(0).has(new Condition<>( badge -> "Test Badge".equals(badge.getName()), "has badge with expected name")); @@ -310,7 +316,7 @@ class ProfileControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(BaseProfileResponse.class); - assertThat(profile.getIdentityKey()).isEqualTo("baz"); + assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); assertThat(profile.getBadges()).isEmpty(); assertThat(profile.getUuid()).isEqualTo(AuthHelper.VALID_PNI_TWO); assertThat(profile.getCapabilities()).isNotNull(); @@ -742,7 +748,7 @@ class ProfileControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(VersionedProfileResponse.class); - assertThat(profile.getBaseProfileResponse().getIdentityKey()).isEqualTo("bar"); + assertThat(profile.getBaseProfileResponse().getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY); assertThat(profile.getName()).isEqualTo("validname"); assertThat(profile.getAbout()).isEqualTo("about"); assertThat(profile.getAboutEmoji()).isEqualTo("emoji"); @@ -1227,9 +1233,12 @@ class ProfileControllerTest { void testBatchIdentityCheck() { try (Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, convertStringToFingerprint("barz")), - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID_TWO, convertStringToFingerprint("bar")), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, convertStringToFingerprint("baz")) + new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, + convertStringToFingerprint(ACCOUNT_IDENTITY_KEY)), + new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, + convertStringToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)), + new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, + convertStringToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)) ))))) { assertThat(response).isNotNull(); assertThat(response.getStatus()).isEqualTo(200); @@ -1238,37 +1247,39 @@ class ProfileControllerTest { assertThat(identityCheckResponse.elements()).isNotNull().isEmpty(); } - Condition isEitherUuid1orUuid2 = new Condition<>(element -> { - if (AuthHelper.VALID_UUID.equals(element.aci())) { - return Arrays.equals(Base64.getDecoder().decode("barz"), element.identityKey()); - } else if (AuthHelper.VALID_UUID_TWO.equals(element.aci())) { - return Arrays.equals(Base64.getDecoder().decode("bar"), element.identityKey()); - } else { - return false; - } - }, "is either UUID 1 or UUID 2 with the correct identity key"); + Condition isAnExpectedUuid = new Condition<>(element -> { + if (AuthHelper.VALID_UUID.equals(element.aci())) { + return Arrays.equals(Base64.getDecoder().decode(ACCOUNT_IDENTITY_KEY), element.identityKey()); + } else if (AuthHelper.VALID_PNI_TWO.equals(element.pni())) { + return Arrays.equals(Base64.getDecoder().decode(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY), element.identityKey()); + } else { + return false; + } + }, "is an expected UUID with the correct identity key"); try (Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, convertStringToFingerprint("else1234")), - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID_TWO, convertStringToFingerprint("another1")), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, convertStringToFingerprint("456")) + new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertStringToFingerprint("else1234")), + new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, + convertStringToFingerprint("another1")), + new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertStringToFingerprint("456")) ))))) { assertThat(response).isNotNull(); assertThat(response.getStatus()).isEqualTo(200); BatchIdentityCheckResponse identityCheckResponse = response.readEntity(BatchIdentityCheckResponse.class); assertThat(identityCheckResponse).isNotNull(); assertThat(identityCheckResponse.elements()).isNotNull().hasSize(2); - assertThat(identityCheckResponse.elements()).element(0).isNotNull().is(isEitherUuid1orUuid2); - assertThat(identityCheckResponse.elements()).element(1).isNotNull().is(isEitherUuid1orUuid2); + assertThat(identityCheckResponse.elements()).element(0).isNotNull().is(isAnExpectedUuid); + assertThat(identityCheckResponse.elements()).element(1).isNotNull().is(isAnExpectedUuid); } List largeElementList = new ArrayList<>(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, convertStringToFingerprint("else1234")), - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID_TWO, convertStringToFingerprint("another1")), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, convertStringToFingerprint("456")))); + new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertStringToFingerprint("else1234")), + new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertStringToFingerprint("another1")), + new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertStringToFingerprint("456")))); for (int i = 0; i < 900; i++) { - largeElementList.add(new BatchIdentityCheckRequest.Element(UUID.randomUUID(), convertStringToFingerprint("abcd"))); + largeElementList.add( + new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertStringToFingerprint("abcd"))); } try (Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(largeElementList)))) { @@ -1277,12 +1288,98 @@ class ProfileControllerTest { BatchIdentityCheckResponse identityCheckResponse = response.readEntity(BatchIdentityCheckResponse.class); assertThat(identityCheckResponse).isNotNull(); assertThat(identityCheckResponse.elements()).isNotNull().hasSize(2); - assertThat(identityCheckResponse.elements()).element(0).isNotNull().is(isEitherUuid1orUuid2); - assertThat(identityCheckResponse.elements()).element(1).isNotNull().is(isEitherUuid1orUuid2); + assertThat(identityCheckResponse.elements()).element(0).isNotNull().is(isAnExpectedUuid); + assertThat(identityCheckResponse.elements()).element(1).isNotNull().is(isAnExpectedUuid); } } - private byte[] convertStringToFingerprint(String base64) { + @Test + void testBatchIdentityCheckDeserialization() { + + Condition isAnExpectedUuid = new Condition<>(element -> { + if (AuthHelper.VALID_UUID.equals(element.aci())) { + return Arrays.equals(Base64.getDecoder().decode(ACCOUNT_IDENTITY_KEY), element.identityKey()); + } else if (AuthHelper.VALID_PNI_TWO.equals(element.pni())) { + return Arrays.equals(Base64.getDecoder().decode(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY), element.identityKey()); + } else { + return false; + } + }, "is an expected UUID with the correct identity key"); + + // null properties are ok to omit + String json = String.format(""" + { + "elements": [ + { "aci": "%s", "fingerprint": "%s" }, + { "pni": "%s", "fingerprint": "%s" }, + { "aci": "%s", "fingerprint": "%s" } + ] + } + """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertStringToFingerprint("else1234")), + AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertStringToFingerprint("another1")), + AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertStringToFingerprint("456"))); + try (Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() + .post(Entity.entity(json, "application/json"))) { + assertThat(response).isNotNull(); + assertThat(response.getStatus()).isEqualTo(200); + BatchIdentityCheckResponse identityCheckResponse = response.readEntity(BatchIdentityCheckResponse.class); + assertThat(identityCheckResponse).isNotNull(); + assertThat(identityCheckResponse.elements()).isNotNull().hasSize(2); + assertThat(identityCheckResponse.elements()).element(0).isNotNull().is(isAnExpectedUuid); + assertThat(identityCheckResponse.elements()).element(1).isNotNull().is(isAnExpectedUuid); + } + } + + @ParameterizedTest + @MethodSource + void testBatchIdentityCheckDeserializationBadRequest(final String json) { + try (Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() + .post(Entity.entity(json, "application/json"))) { + assertThat(response).isNotNull(); + assertThat(response.getStatus()).isEqualTo(400); + } + } + + static Stream testBatchIdentityCheckDeserializationBadRequest() { + return Stream.of( + Arguments.of( // aci and pni cannot both be null + """ + { + "elements": [ + { "aci": null, "pni": null, "fingerprint": "%s" } + ] + } + """), + Arguments.of( // an empty string is also invalid + """ + { + "elements": [ + { "aci": "", "pni": null, "fingerprint": "%s" } + ] + } + """ + ), + Arguments.of( // as is a blank string + """ + { + "elements": [ + { "aci": null, "pni": " ", "fingerprint": "%s" } + ] + } + """), + Arguments.of( // aci and pni cannot both be non-null + String.format(""" + { + "elements": [ + { "aci": "%s", "pni": "%s", "fingerprint": "%s" } + ] + } + """, AuthHelper.VALID_UUID, AuthHelper.VALID_PNI, + Base64.getEncoder().encodeToString(convertStringToFingerprint("else1234")))) + ); + } + + private static byte[] convertStringToFingerprint(String base64) { MessageDigest sha256; try { sha256 = MessageDigest.getInstance("SHA-256");