diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 7dedebd50..75124a48c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -20,6 +20,7 @@ import java.util.UUID; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; +import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; import javax.ws.rs.PUT; @@ -92,16 +93,19 @@ public class KeysController { Device device = disabledPermittedAuth.getAuthenticatedDevice(); boolean updateAccount = false; - if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) { - updateAccount = true; - } - - if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) { - updateAccount = true; - } - final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType); + if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) { + updateAccount = true; + } + + if (!preKeys.getIdentityKey().equals(usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey())) { + updateAccount = true; + if (!device.isMaster()) { + throw new ForbiddenException(); + } + } + if (updateAccount) { account = accounts.update(account, a -> { a.getDevice(device.getId()).ifPresent(d -> { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index bf305a2b6..b02d87b5c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -570,10 +570,10 @@ class KeysControllerTest { Response response = resources.getJerseyTest() - .target("/v2/keys") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) - .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); + .target("/v2/keys") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(204); @@ -589,4 +589,24 @@ class KeysControllerTest { verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any()); } + + @Test + void putIdentityKeyNonPrimary() { + final PreKey preKey = new PreKey(31337, "foobar"); + final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig"); + final String identityKey = "barbar"; + + List preKeys = List.of(preKey); + + PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys); + + Response response = + resources.getJerseyTest() + .target("/v2/keys") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, 2L, AuthHelper.VALID_PASSWORD_3_LINKED)) + .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(403); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 104b9946d..655f813c5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -44,6 +44,12 @@ public class AuthHelper { public static final UUID VALID_UUID_TWO = UUID.randomUUID(); public static final String VALID_PASSWORD_TWO = "baz"; + public static final String VALID_NUMBER_3 = "+14445556666"; + public static final UUID VALID_UUID_3 = UUID.randomUUID(); + public static final UUID VALID_PNI_3 = UUID.randomUUID(); + public static final String VALID_PASSWORD_3_PRIMARY = "3primary"; + public static final String VALID_PASSWORD_3_LINKED = "3linked"; + public static final UUID INVALID_UUID = UUID.randomUUID(); public static final String INVALID_PASSWORD = "bar"; @@ -62,25 +68,34 @@ public class AuthHelper { public static Account VALID_ACCOUNT_TWO = mock(Account.class ); public static Account DISABLED_ACCOUNT = mock(Account.class ); public static Account UNDISCOVERABLE_ACCOUNT = mock(Account.class ); + public static Account VALID_ACCOUNT_3 = mock(Account.class ); - public static Device VALID_DEVICE = mock(Device.class); - public static Device VALID_DEVICE_TWO = mock(Device.class); - public static Device DISABLED_DEVICE = mock(Device.class); - public static Device UNDISCOVERABLE_DEVICE = mock(Device.class); + public static Device VALID_DEVICE = mock(Device.class); + public static Device VALID_DEVICE_TWO = mock(Device.class); + public static Device DISABLED_DEVICE = mock(Device.class); + public static Device UNDISCOVERABLE_DEVICE = mock(Device.class); + public static Device VALID_DEVICE_3_PRIMARY = mock(Device.class); + public static Device VALID_DEVICE_3_LINKED = mock(Device.class); - private static AuthenticationCredentials VALID_CREDENTIALS = mock(AuthenticationCredentials.class); - private static AuthenticationCredentials VALID_CREDENTIALS_TWO = mock(AuthenticationCredentials.class); - private static AuthenticationCredentials DISABLED_CREDENTIALS = mock(AuthenticationCredentials.class); - private static AuthenticationCredentials UNDISCOVERABLE_CREDENTIALS = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials VALID_CREDENTIALS = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials VALID_CREDENTIALS_TWO = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials VALID_CREDENTIALS_3_PRIMARY = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials VALID_CREDENTIALS_3_LINKED = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials DISABLED_CREDENTIALS = mock(AuthenticationCredentials.class); + private static AuthenticationCredentials UNDISCOVERABLE_CREDENTIALS = mock(AuthenticationCredentials.class); public static PolymorphicAuthDynamicFeature getAuthFilter() { when(VALID_CREDENTIALS.verify("foo")).thenReturn(true); when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true); + when(VALID_CREDENTIALS_3_PRIMARY.verify(VALID_PASSWORD_3_PRIMARY)).thenReturn(true); + when(VALID_CREDENTIALS_3_LINKED.verify(VALID_PASSWORD_3_LINKED)).thenReturn(true); when(DISABLED_CREDENTIALS.verify(DISABLED_PASSWORD)).thenReturn(true); when(UNDISCOVERABLE_CREDENTIALS.verify(UNDISCOVERABLE_PASSWORD)).thenReturn(true); when(VALID_DEVICE.getAuthenticationCredentials()).thenReturn(VALID_CREDENTIALS); when(VALID_DEVICE_TWO.getAuthenticationCredentials()).thenReturn(VALID_CREDENTIALS_TWO); + when(VALID_DEVICE_3_PRIMARY.getAuthenticationCredentials()).thenReturn(VALID_CREDENTIALS_3_PRIMARY); + when(VALID_DEVICE_3_LINKED.getAuthenticationCredentials()).thenReturn(VALID_CREDENTIALS_3_LINKED); when(DISABLED_DEVICE.getAuthenticationCredentials()).thenReturn(DISABLED_CREDENTIALS); when(UNDISCOVERABLE_DEVICE.getAuthenticationCredentials()).thenReturn(UNDISCOVERABLE_CREDENTIALS); @@ -88,16 +103,22 @@ public class AuthHelper { when(VALID_DEVICE_TWO.isMaster()).thenReturn(true); when(DISABLED_DEVICE.isMaster()).thenReturn(true); when(UNDISCOVERABLE_DEVICE.isMaster()).thenReturn(true); + when(VALID_DEVICE_3_PRIMARY.isMaster()).thenReturn(true); + when(VALID_DEVICE_3_LINKED.isMaster()).thenReturn(false); when(VALID_DEVICE.getId()).thenReturn(1L); when(VALID_DEVICE_TWO.getId()).thenReturn(1L); when(DISABLED_DEVICE.getId()).thenReturn(1L); when(UNDISCOVERABLE_DEVICE.getId()).thenReturn(1L); + when(VALID_DEVICE_3_PRIMARY.getId()).thenReturn(1L); + when(VALID_DEVICE_3_LINKED.getId()).thenReturn(2L); when(VALID_DEVICE.isEnabled()).thenReturn(true); when(VALID_DEVICE_TWO.isEnabled()).thenReturn(true); when(DISABLED_DEVICE.isEnabled()).thenReturn(false); when(UNDISCOVERABLE_DEVICE.isMaster()).thenReturn(true); + when(VALID_DEVICE_3_PRIMARY.isEnabled()).thenReturn(true); + when(VALID_DEVICE_3_LINKED.isEnabled()).thenReturn(true); when(VALID_ACCOUNT.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE)); when(VALID_ACCOUNT.getMasterDevice()).thenReturn(Optional.of(VALID_DEVICE)); @@ -107,6 +128,9 @@ public class AuthHelper { when(DISABLED_ACCOUNT.getMasterDevice()).thenReturn(Optional.of(DISABLED_DEVICE)); when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(1L))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); when(UNDISCOVERABLE_ACCOUNT.getMasterDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); + when(VALID_ACCOUNT_3.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); + when(VALID_ACCOUNT_3.getMasterDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); + when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED)); when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6); @@ -119,16 +143,21 @@ public class AuthHelper { when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID); when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); + when(VALID_ACCOUNT_3.getNumber()).thenReturn(VALID_NUMBER_3); + when(VALID_ACCOUNT_3.getUuid()).thenReturn(VALID_UUID_3); + when(VALID_ACCOUNT_3.getPhoneNumberIdentifier()).thenReturn(VALID_PNI_3); when(VALID_ACCOUNT.isEnabled()).thenReturn(true); when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true); when(DISABLED_ACCOUNT.isEnabled()).thenReturn(false); when(UNDISCOVERABLE_ACCOUNT.isEnabled()).thenReturn(true); + when(VALID_ACCOUNT_3.isEnabled()).thenReturn(true); when(VALID_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(true); when(VALID_ACCOUNT_TWO.isDiscoverableByPhoneNumber()).thenReturn(true); when(DISABLED_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(true); when(UNDISCOVERABLE_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(false); + when(VALID_ACCOUNT_3.isDiscoverableByPhoneNumber()).thenReturn(true); when(VALID_ACCOUNT.getIdentityKey()).thenReturn(VALID_IDENTITY); @@ -147,6 +176,10 @@ public class AuthHelper { when(ACCOUNTS_MANAGER.getByE164(UNDISCOVERABLE_NUMBER)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); when(ACCOUNTS_MANAGER.getByAccountIdentifier(UNDISCOVERABLE_UUID)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); + when(ACCOUNTS_MANAGER.getByE164(VALID_NUMBER_3)).thenReturn(Optional.of(VALID_ACCOUNT_3)); + when(ACCOUNTS_MANAGER.getByAccountIdentifier(VALID_UUID_3)).thenReturn(Optional.of(VALID_ACCOUNT_3)); + when(ACCOUNTS_MANAGER.getByPhoneNumberIdentifier(VALID_PNI_3)).thenReturn(Optional.of(VALID_ACCOUNT_3)); + AccountsHelper.setupMockUpdateForAuthHelper(ACCOUNTS_MANAGER); for (TestAccount testAccount : TEST_ACCOUNTS) { @@ -162,6 +195,10 @@ public class AuthHelper { DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter)); } + public static String getAuthHeader(UUID uuid, long deviceId, String password) { + return getAuthHeader(uuid.toString() + "." + deviceId, password); + } + public static String getAuthHeader(UUID uuid, String password) { return getAuthHeader(uuid.toString(), password); }