diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java index 4d308e7bd..db1bd2e26 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java @@ -12,11 +12,9 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; -import javax.ws.rs.core.SecurityContext; import org.glassfish.jersey.server.ContainerRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,17 +42,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled"; private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled"; - private Optional findAccount(final ContainerRequest containerRequest) { - return Optional.ofNullable(containerRequest.getSecurityContext()) - .map(SecurityContext::getUserPrincipal) - .map(principal -> { - if (principal instanceof AccountAndAuthenticatedDeviceHolder) { - return ((AccountAndAuthenticatedDeviceHolder) principal).getAccount(); - } - return null; - }); - } - @VisibleForTesting Map buildDevicesEnabledMap(final Account account) { return account.getDevices().stream() @@ -63,11 +50,11 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres } @Override - public void handleRequestStart(final ContainerRequest request) { + public void handleRequestFiltered(final ContainerRequest request) { // The authenticated principal, if any, will be available after filters have run. // Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices, // before carrying out the request’s business logic. - findAccount(request) + ContainerRequestUtil.getAuthenticatedAccount(request) .ifPresent( account -> { request.setProperty(ACCOUNT_ENABLED, account.isEnabled()); @@ -87,7 +74,7 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres @SuppressWarnings("unchecked") final Map initialDevicesEnabled = (Map) request.getProperty(DEVICES_ENABLED); - return findAccount(request).map(account -> { + return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> { final Set deviceIdsToDisplace; if (account.isEnabled() != accountInitiallyEnabled) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java new file mode 100644 index 000000000..f551b9761 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java @@ -0,0 +1,21 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import org.glassfish.jersey.server.ContainerRequest; +import org.whispersystems.textsecuregcm.storage.Account; +import javax.ws.rs.core.SecurityContext; +import java.util.Optional; + +class ContainerRequestUtil { + + static Optional getAuthenticatedAccount(final ContainerRequest request) { + return Optional.ofNullable(request.getSecurityContext()) + .map(SecurityContext::getUserPrincipal) + .map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder + ? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java new file mode 100644 index 000000000..85fb6530a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java @@ -0,0 +1,45 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; +import org.glassfish.jersey.server.ContainerRequest; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.util.Pair; + +public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { + + private static final String INITIAL_NUMBER_KEY = + PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber"; + + @Override + public void handleRequestFiltered(final ContainerRequest request) { + ContainerRequestUtil.getAuthenticatedAccount(request) + .ifPresent(account -> request.setProperty(INITIAL_NUMBER_KEY, account.getNumber())); + } + + @Override + public List> handleRequestFinished(final ContainerRequest request) { + final String initialNumber = (String) request.getProperty(INITIAL_NUMBER_KEY); + + if (initialNumber != null) { + final Optional maybeAuthenticatedAccount = ContainerRequestUtil.getAuthenticatedAccount(request); + + return maybeAuthenticatedAccount + .filter(account -> !initialNumber.equals(account.getNumber())) + .map(account -> account.getDevices().stream() + .map(device -> new Pair<>(account.getUuid(), device.getId())) + .collect(Collectors.toList())) + .orElse(Collections.emptyList()); + } else { + return Collections.emptyList(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java index 68dac01d0..a6d2a952b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java @@ -20,7 +20,8 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven public WebsocketRefreshApplicationEventListener(final ClientPresenceManager clientPresenceManager) { this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, - new AuthEnablementRefreshRequirementProvider()); + new AuthEnablementRefreshRequirementProvider(), + new PhoneNumberChangeRefreshRequirementProvider()); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java index 8d85fef92..1029a9971 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java @@ -43,7 +43,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene public void onEvent(final RequestEvent event) { if (event.getType() == Type.REQUEST_FILTERED) { for (final WebsocketRefreshRequirementProvider provider : providers) { - provider.handleRequestStart(event.getContainerRequest()); + provider.handleRequestFiltered(event.getContainerRequest()); } } else if (event.getType() == Type.FINISHED) { final AtomicInteger displacedDevices = new AtomicInteger(0); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java index f036907f4..e34c800a1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java @@ -21,7 +21,7 @@ public interface WebsocketRefreshRequirementProvider { * * @param request the request to observe */ - void handleRequestStart(ContainerRequest request); + void handleRequestFiltered(ContainerRequest request); /** * Processes a request after all normal request handling has been completed. diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 5c1d3bf7b..73411f824 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -89,10 +89,10 @@ class AuthEnablementRefreshRequirementProviderTest { private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class); - private Account account = new Account(); + private final Account account = new Account(); private Device authenticatedDevice = DevicesHelper.createDevice(1L); - private Supplier> principalSupplier = () -> Optional.of( + private final Supplier> principalSupplier = () -> Optional.of( new TestPrincipal("test", account, authenticatedDevice)); private final ResourceExtension resources = ResourceExtension.builder() @@ -109,14 +109,15 @@ class AuthEnablementRefreshRequirementProviderTest { private ClientPresenceManager clientPresenceManager; - private WebsocketRefreshRequestEventListener listener; private AuthEnablementRefreshRequirementProvider provider; @BeforeEach void setup() { clientPresenceManager = mock(ClientPresenceManager.class); provider = new AuthEnablementRefreshRequirementProvider(); - listener = new WebsocketRefreshRequestEventListener(clientPresenceManager, provider); + final WebsocketRefreshRequestEventListener listener = + new WebsocketRefreshRequestEventListener(clientPresenceManager, provider); + when(applicationEventListener.onRequest(any())).thenReturn(listener); final UUID uuid = UUID.randomUUID(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java new file mode 100644 index 000000000..300c45dd0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import org.glassfish.jersey.server.ContainerRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Pair; + +import javax.annotation.Nullable; +import javax.ws.rs.core.SecurityContext; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class PhoneNumberChangeRefreshRequirementProviderTest { + + private PhoneNumberChangeRefreshRequirementProvider provider; + + private Account account; + private ContainerRequest request; + + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); + private static final String NUMBER = "+18005551234"; + private static final String CHANGED_NUMBER = "+18005554321"; + + @BeforeEach + void setUp() { + provider = new PhoneNumberChangeRefreshRequirementProvider(); + + account = mock(Account.class); + final Device device = mock(Device.class); + + when(account.getUuid()).thenReturn(ACCOUNT_UUID); + when(account.getNumber()).thenReturn(NUMBER); + when(account.getDevices()).thenReturn(Set.of(device)); + when(device.getId()).thenReturn(Device.MASTER_ID); + + request = mock(ContainerRequest.class); + + final Map requestProperties = new HashMap<>(); + + doAnswer(invocation -> { + requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1)); + return null; + }).when(request).setProperty(anyString(), any()); + + when(request.getProperty(anyString())).thenAnswer( + invocation -> requestProperties.get(invocation.getArgument(0, String.class))); + } + + @Test + void handleRequestNoChange() { + setAuthenticatedAccount(request, account); + + provider.handleRequestFiltered(request); + assertEquals(Collections.emptyList(), provider.handleRequestFinished(request)); + } + + @Test + void handleRequestNumberChange() { + setAuthenticatedAccount(request, account); + + provider.handleRequestFiltered(request); + when(account.getNumber()).thenReturn(CHANGED_NUMBER); + assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.MASTER_ID)), provider.handleRequestFinished(request)); + } + + @Test + void handleRequestNoAuthenticatedAccount() { + final ContainerRequest request = mock(ContainerRequest.class); + setAuthenticatedAccount(request, null); + + provider.handleRequestFiltered(request); + assertEquals(Collections.emptyList(), provider.handleRequestFinished(request)); + } + + private void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) { + final SecurityContext securityContext = mock(SecurityContext.class); + + when(mockRequest.getSecurityContext()).thenReturn(securityContext); + + if (account != null) { + final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class); + + when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount); + when(authenticatedAccount.getAccount()).thenReturn(account); + } else { + when(securityContext.getUserPrincipal()).thenReturn(null); + } + } +}