diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java index aa6997677..366bffe2b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java @@ -5,28 +5,30 @@ package org.whispersystems.textsecuregcm.auth; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; - +import java.security.MessageDigest; +import java.util.Optional; import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.NotFoundException; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Response; -import java.security.MessageDigest; -import java.util.Optional; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class OptionalAccess { public static String ALL_DEVICES_SELECTOR = "*"; - public static void verify(Optional requestAccount, - Optional accessKey, - Optional targetAccount, - String deviceSelector) - { + public static void verify(Optional requestAccount, + Optional accessKey, + Optional targetAccount, + ServiceIdentifier targetIdentifier, + String deviceSelector) { + try { - verify(requestAccount, accessKey, targetAccount); + verify(requestAccount, accessKey, targetAccount, targetIdentifier); if (!ALL_DEVICES_SELECTOR.equals(deviceSelector)) { byte deviceId = Byte.parseByte(deviceSelector); @@ -48,9 +50,11 @@ public class OptionalAccess { } } - public static void verify(Optional requestAccount, - Optional accessKey, - Optional targetAccount) { + public static void verify(Optional requestAccount, + Optional accessKey, + Optional targetAccount, + ServiceIdentifier targetIdentifier) { + if (requestAccount.isPresent()) { // Authenticated requests are never unauthorized; if the target exists, return OK, otherwise throw not-found. if (targetAccount.isPresent()) { @@ -74,6 +78,15 @@ public class OptionalAccess { return; } + if (!targetAccount.get().isIdentifiedBy(targetIdentifier)) { + throw new IllegalArgumentException("Target account is not identified by the given identifier"); + } + + // Unidentified access is only for ACI identities + if (IdentityType.PNI.equals(targetIdentifier.identityType())) { + throw new NotAuthorizedException(Response.Status.UNAUTHORIZED); + } + // At this point, any successful authentication requires a real access key on the target account if (targetAccount.get().getUnidentifiedAccessKey().isEmpty()) { throw new NotAuthorizedException(Response.Status.UNAUTHORIZED); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 62ba345f5..c567b8978 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -349,7 +349,7 @@ public class KeysController { throw new NotAuthorizedException(e); } } else { - OptionalAccess.verify(account, accessKey, maybeTarget, deviceId); + OptionalAccess.verify(account, accessKey, maybeTarget, targetIdentifier, deviceId); } final Account target = maybeTarget.orElseThrow(NotFoundException::new); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 3160d7584..51738dc7a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -374,7 +374,8 @@ public class MessageController { throw new NotFoundException(); } } else { - OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); + OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination, + destinationIdentifier); } boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getDevices().size() > 1; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 8d435d5b9..f79232e72 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -19,7 +19,6 @@ import java.time.Clock; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; -import java.util.Base64; import java.util.Collection; import java.util.Collections; import java.util.HexFormat; @@ -32,7 +31,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; import java.util.stream.Collectors; -import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.BadRequestException; @@ -48,7 +46,6 @@ import javax.ws.rs.PathParam; import javax.ws.rs.ProcessingException; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.Context; import javax.ws.rs.core.HttpHeaders; @@ -503,7 +500,7 @@ public class ProfileController { final Optional maybeTargetAccount = accountsManager.getByServiceIdentifier(accountIdentifier); - OptionalAccess.verify(maybeRequester, maybeAccessKey, maybeTargetAccount); + OptionalAccess.verify(maybeRequester, maybeAccessKey, maybeTargetAccount, accountIdentifier); assert maybeTargetAccount.isPresent(); return maybeTargetAccount.get(); @@ -520,19 +517,4 @@ public class ProfileController { now.format(PostPolicyGenerator.AWS_DATE_TIME), policy.second(), signature); } - @Nullable - private static byte[] decodeFromBase64(@Nullable final String input) { - if (input == null) { - return null; - } - return Base64.getDecoder().decode(input); - } - - @Nullable - private static String encodeToBase64(@Nullable final byte[] input) { - if (input == null) { - return null; - } - return Base64.getEncoder().encodeToString(input); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java index c5c9c6852..c05a6a7a5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java @@ -16,11 +16,16 @@ import java.util.Base64; import java.util.List; import java.util.Optional; import java.util.OptionalInt; +import java.util.UUID; import javax.ws.rs.WebApplicationException; import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -32,15 +37,17 @@ class OptionalAccessTest { void verify(final Optional requestAccount, final Optional accessKey, final Optional targetAccount, + final ServiceIdentifier targetIdentifier, final String deviceSelector, final OptionalInt expectedStatusCode) { expectedStatusCode.ifPresentOrElse(statusCode -> { final WebApplicationException webApplicationException = assertThrows(WebApplicationException.class, - () -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector)); + () -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, targetIdentifier, deviceSelector)); assertEquals(statusCode, webApplicationException.getResponse().getStatus()); - }, () -> assertDoesNotThrow(() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector))); + }, () -> assertDoesNotThrow(() -> + OptionalAccess.verify(requestAccount, accessKey, targetAccount, targetIdentifier, deviceSelector))); } private static List verify() { @@ -53,28 +60,39 @@ class OptionalAccessTest { new Anonymous(Base64.getEncoder().encodeToString((unidentifiedAccessKey + "-incorrect").getBytes())); final Account targetAccount = mock(Account.class); + final ServiceIdentifier targetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final ServiceIdentifier targetAccountPniIdentifier = new PniServiceIdentifier(UUID.randomUUID()); when(targetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class))); when(targetAccount.getUnidentifiedAccessKey()) .thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8))); + when(targetAccount.isIdentifiedBy(targetAccountAciIdentifier)).thenReturn(true); + when(targetAccount.isIdentifiedBy(targetAccountPniIdentifier)).thenReturn(true); final Account allowAllTargetAccount = mock(Account.class); + final ServiceIdentifier allowAllTargetAccountPniIdentifier = new PniServiceIdentifier(UUID.randomUUID()); when(allowAllTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class))); when(allowAllTargetAccount.isUnrestrictedUnidentifiedAccess()).thenReturn(true); + when(allowAllTargetAccount.isIdentifiedBy(allowAllTargetAccountPniIdentifier)).thenReturn(true); final Account noUakTargetAccount = mock(Account.class); + final ServiceIdentifier noUakTargetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID()); when(noUakTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class))); when(noUakTargetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.empty()); + when(noUakTargetAccount.isIdentifiedBy(noUakTargetAccountAciIdentifier)).thenReturn(true); final Account inactiveTargetAccount = mock(Account.class); + final ServiceIdentifier inactiveTargetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID()); when(inactiveTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class))); when(inactiveTargetAccount.getUnidentifiedAccessKey()) .thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8))); + when(inactiveTargetAccount.isIdentifiedBy(inactiveTargetAccountAciIdentifier)).thenReturn(true); return List.of( // Unidentified caller; correct UAK Arguments.of(Optional.empty(), Optional.of(correctUakHeader), Optional.of(targetAccount), + targetAccountAciIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.empty()), @@ -82,6 +100,7 @@ class OptionalAccessTest { Arguments.of(Optional.of(mock(Account.class)), Optional.empty(), Optional.of(targetAccount), + targetAccountAciIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.empty()), @@ -89,6 +108,7 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.empty(), Optional.empty(), + new AciServiceIdentifier(UUID.randomUUID()), OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.of(401)), @@ -96,6 +116,7 @@ class OptionalAccessTest { Arguments.of(Optional.of(mock(Account.class)), Optional.empty(), Optional.empty(), + new AciServiceIdentifier(UUID.randomUUID()), OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.of(404)), @@ -103,6 +124,7 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.of(correctUakHeader), Optional.of(targetAccount), + targetAccountAciIdentifier, String.valueOf(Device.PRIMARY_ID + 1), OptionalInt.of(401)), @@ -110,6 +132,7 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.of(incorrectUakHeader), Optional.of(targetAccount), + targetAccountAciIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.of(401)), @@ -117,13 +140,15 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.of(correctUakHeader), Optional.of(noUakTargetAccount), + noUakTargetAccountAciIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.of(401)), - // Unidentified caller; target account found, allows unrestricted unidentified access + // Unidentified caller; target account found, allows unrestricted unidentified access, so PNI target doesn't matter Arguments.of(Optional.empty(), Optional.of(incorrectUakHeader), Optional.of(allowAllTargetAccount), + allowAllTargetAccountPniIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.empty()), @@ -131,6 +156,7 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.of(correctUakHeader), Optional.of(inactiveTargetAccount), + inactiveTargetAccountAciIdentifier, OptionalAccess.ALL_DEVICES_SELECTOR, OptionalInt.empty()), @@ -138,8 +164,35 @@ class OptionalAccessTest { Arguments.of(Optional.empty(), Optional.of(correctUakHeader), Optional.of(targetAccount), + targetAccountAciIdentifier, "not a valid identifier", - OptionalInt.of(422)) + OptionalInt.of(422)), + + // Unidentified caller; target account found, but PNI identifier + Arguments.of(Optional.empty(), + Optional.of(correctUakHeader), + Optional.of(targetAccount), + targetAccountPniIdentifier, + OptionalAccess.ALL_DEVICES_SELECTOR, + OptionalInt.of(401)) ); } + + @Test + void testTargetIdentifierIllegalArgument() { + final String unidentifiedAccessKey = RandomStringUtils.randomAlphanumeric(16); + + final Anonymous correctUakHeader = + new Anonymous(Base64.getEncoder().encodeToString(unidentifiedAccessKey.getBytes())); + + final Account targetAccount = mock(Account.class); + when(targetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class))); + when(targetAccount.getUnidentifiedAccessKey()) + .thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8))); + + assertThrows(IllegalArgumentException.class, + () -> OptionalAccess.verify(Optional.empty(), Optional.of(correctUakHeader), Optional.of(targetAccount), + new AciServiceIdentifier(UUID.randomUUID()))); + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index c125e0840..9501a74f3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -38,7 +38,6 @@ import java.util.OptionalInt; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; - import javax.ws.rs.client.Entity; import javax.ws.rs.client.Invocation; import javax.ws.rs.core.MediaType; @@ -51,7 +50,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.IdentityKey; @@ -227,7 +225,9 @@ class KeysControllerTest { when(sampleDevice4.getId()).thenReturn(sampleDevice4Id); when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); + when(existsAccount.isIdentifiedBy(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(true); when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI); + when(existsAccount.isIdentifiedBy(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(true); when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID); when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI); when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index a36421b32..e43db07c3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -202,6 +202,8 @@ class ProfileControllerTest { when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty()); when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH)); when(profileAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + when(profileAccount.isIdentifiedBy(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO)))).thenReturn(true); + when(profileAccount.isIdentifiedBy(eq(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO)))).thenReturn(true); capabilitiesAccount = mock(Account.class); @@ -1166,6 +1168,7 @@ class ProfileControllerTest { when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version)); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true); final Instant expiration = Instant.now().plus(org.whispersystems.textsecuregcm.util.ProfileHelper.EXPIRING_PROFILE_KEY_CREDENTIAL_EXPIRATION) .truncatedTo(ChronoUnit.DAYS); @@ -1231,6 +1234,7 @@ class ProfileControllerTest { final Account account = mock(Account.class); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(account)); when(profilesManager.get(AuthHelper.VALID_UUID, version)).thenReturn(Optional.of(versionedProfile));