diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 1c7e40a32..c9fbbb51a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -595,7 +595,7 @@ public class WhisperServerService extends Application( ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))); - environment.jersey().register(new WebsocketRefreshApplicationEventListener(clientPresenceManager)); + environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); environment.jersey().register(new TimestampResponseFilter()); environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales())); @@ -607,7 +607,7 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); - provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(clientPresenceManager)); + provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java index f44c2d2c0..0bc85a1f0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java @@ -14,9 +14,11 @@ import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.monitoring.RequestEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Pair; @@ -33,34 +35,48 @@ import org.whispersystems.textsecuregcm.util.Pair; */ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { + private final AccountsManager accountsManager; + private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class); + private static final String ACCOUNT_UUID = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountUuid"; private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled"; + public AuthEnablementRefreshRequirementProvider(final AccountsManager accountsManager) { + this.accountsManager = accountsManager; + } + @VisibleForTesting - Map buildDevicesEnabledMap(final Account account) { + static Map buildDevicesEnabledMap(final Account account) { return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled)); } @Override - public void handleRequestFiltered(final ContainerRequest request) { - // The authenticated principal, if any, will be available after filters have run. - // Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out - // the request’s business logic. - ContainerRequestUtil.getAuthenticatedAccount(request) - .ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account))); + public void handleRequestFiltered(final RequestEvent requestEvent) { + if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(ChangesDeviceEnabledState.class) != null) { + // The authenticated principal, if any, will be available after filters have run. + // Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out + // the request’s business logic. + ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()).ifPresent(account -> + setAccount(requestEvent.getContainerRequest(), account)); + } + } + + public static void setAccount(final ContainerRequest containerRequest, final Account account) { + containerRequest.setProperty(ACCOUNT_UUID, account.getUuid()); + containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)); } @Override - public List> handleRequestFinished(final ContainerRequest request) { + public List> handleRequestFinished(final RequestEvent requestEvent) { // Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did // change or if a devices was added or removed, all devices must disconnect and reauthenticate. - if (request.getProperty(DEVICES_ENABLED) != null) { + if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) { @SuppressWarnings("unchecked") final Map initialDevicesEnabled = - (Map) request.getProperty(DEVICES_ENABLED); + (Map) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED); - return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> { + return accountsManager.get((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> { final Set deviceIdsToDisplace; final Map currentDevicesEnabled = buildDevicesEnabledMap(account); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesDeviceEnabledState.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesDeviceEnabledState.java new file mode 100644 index 000000000..dc4911cdb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesDeviceEnabledState.java @@ -0,0 +1,20 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicates that an endpoint may change the "enabled" state of one or more devices associated with an account, and that + * any websockets associated with the account may need to be refreshed after a call to that endpoint. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface ChangesDeviceEnabledState { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java index 85fb6530a..b51978687 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java @@ -11,6 +11,7 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.monitoring.RequestEvent; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Pair; @@ -20,17 +21,18 @@ public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRef PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber"; @Override - public void handleRequestFiltered(final ContainerRequest request) { - ContainerRequestUtil.getAuthenticatedAccount(request) - .ifPresent(account -> request.setProperty(INITIAL_NUMBER_KEY, account.getNumber())); + public void handleRequestFiltered(final RequestEvent requestEvent) { + ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()) + .ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber())); } @Override - public List> handleRequestFinished(final ContainerRequest request) { - final String initialNumber = (String) request.getProperty(INITIAL_NUMBER_KEY); + public List> handleRequestFinished(final RequestEvent requestEvent) { + final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY); if (initialNumber != null) { - final Optional maybeAuthenticatedAccount = ContainerRequestUtil.getAuthenticatedAccount(request); + final Optional maybeAuthenticatedAccount = + ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()); return maybeAuthenticatedAccount .filter(account -> !initialNumber.equals(account.getNumber())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java index a6d2a952b..ad7ffeb9c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java @@ -10,6 +10,7 @@ import org.glassfish.jersey.server.monitoring.ApplicationEventListener; import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.storage.AccountsManager; /** * Delegates request events to a listener that watches for intra-request changes that require websocket refreshes @@ -18,9 +19,11 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener; - public WebsocketRefreshApplicationEventListener(final ClientPresenceManager clientPresenceManager) { + public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager, + final ClientPresenceManager clientPresenceManager) { + this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, - new AuthEnablementRefreshRequirementProvider(), + new AuthEnablementRefreshRequirementProvider(accountsManager), new PhoneNumberChangeRefreshRequirementProvider()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java index 1029a9971..e61708d9a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java @@ -16,6 +16,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import javax.ws.rs.container.ResourceInfo; +import javax.ws.rs.core.Context; + import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; public class WebsocketRefreshRequestEventListener implements RequestEventListener { @@ -39,17 +42,20 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene this.providers = providers; } + @Context + private ResourceInfo resourceInfo; + @Override public void onEvent(final RequestEvent event) { if (event.getType() == Type.REQUEST_FILTERED) { for (final WebsocketRefreshRequirementProvider provider : providers) { - provider.handleRequestFiltered(event.getContainerRequest()); + provider.handleRequestFiltered(event); } } else if (event.getType() == Type.FINISHED) { final AtomicInteger displacedDevices = new AtomicInteger(0); Arrays.stream(providers) - .flatMap(provider -> provider.handleRequestFinished(event.getContainerRequest()).stream()) + .flatMap(provider -> provider.handleRequestFinished(event).stream()) .distinct() .forEach(pair -> { try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java index e34c800a1..862c31238 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.auth; import java.util.List; import java.util.UUID; import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.monitoring.RequestEvent; import org.whispersystems.textsecuregcm.util.Pair; /** @@ -19,16 +20,16 @@ public interface WebsocketRefreshRequirementProvider { /** * Processes a request after filters have run and the request has been mapped to a destination controller. * - * @param request the request to observe + * @param requestEvent the request event to observe */ - void handleRequestFiltered(ContainerRequest request); + void handleRequestFiltered(RequestEvent requestEvent); /** * Processes a request after all normal request handling has been completed. * - * @param request the request to observe + * @param requestEvent the request event to observe * @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a * result of the observed request */ - List> handleRequestFinished(ContainerRequest request); + List> handleRequestFinished(RequestEvent requestEvent); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index b3491e9f6..18faa99cd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -43,6 +43,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; +import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; @@ -434,6 +435,7 @@ public class AccountController { @PUT @Path("/gcm/") @Consumes(MediaType.APPLICATION_JSON) + @ChangesDeviceEnabledState public void setGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid GcmRegistrationId registrationId) { Account account = disabledPermittedAuth.getAccount(); @@ -455,6 +457,7 @@ public class AccountController { @Timed @DELETE @Path("/gcm/") + @ChangesDeviceEnabledState public void deleteGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) { Account account = disabledPermittedAuth.getAccount(); Device device = disabledPermittedAuth.getAuthenticatedDevice(); @@ -470,6 +473,7 @@ public class AccountController { @PUT @Path("/apn/") @Consumes(MediaType.APPLICATION_JSON) + @ChangesDeviceEnabledState public void setApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid ApnRegistrationId registrationId) { Account account = disabledPermittedAuth.getAccount(); @@ -486,6 +490,7 @@ public class AccountController { @Timed @DELETE @Path("/apn/") + @ChangesDeviceEnabledState public void deleteApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) { Account account = disabledPermittedAuth.getAccount(); Device device = disabledPermittedAuth.getAuthenticatedDevice(); @@ -538,6 +543,7 @@ public class AccountController { @PUT @Path("/attributes/") @Consumes(MediaType.APPLICATION_JSON) + @ChangesDeviceEnabledState public void setAccountAttributes(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @HeaderParam("X-Signal-Agent") String userAgent, @Valid AccountAttributes attributes) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 24d4789c2..8ea1041f5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -22,11 +22,15 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.glassfish.jersey.server.ContainerRequest; +import org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; +import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceInfo; @@ -89,6 +93,7 @@ public class DeviceController { @Timed @DELETE @Path("/{device_id}") + @ChangesDeviceEnabledState public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) { Account account = auth.getAccount(); if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) { @@ -143,10 +148,12 @@ public class DeviceController { @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Path("/{verification_code}") + @ChangesDeviceEnabledState public DeviceResponse verifyDeviceToken(@PathParam("verification_code") String verificationCode, @HeaderParam("Authorization") BasicAuthorizationHeader authorizationHeader, @HeaderParam("User-Agent") String userAgent, - @Valid AccountAttributes accountAttributes) + @Valid AccountAttributes accountAttributes, + @Context ContainerRequest containerRequest) throws RateLimitExceededException, DeviceLimitExceededException { @@ -167,6 +174,11 @@ public class DeviceController { throw new WebApplicationException(Response.status(403).build()); } + // Normally, the the "do we need to refresh somebody's websockets" listener can do this on its own. In this case, + // we're not using the conventional authentication system, and so we need to give it a hint so it knows who the + // active user is and what their device states look like. + AuthEnablementRefreshRequirementProvider.setAccount(containerRequest, account.get()); + int maxDeviceLimit = MAX_DEVICES; if (maxDeviceConfiguration.containsKey(account.get().getNumber())) { @@ -191,11 +203,11 @@ public class DeviceController { device.setCreated(System.currentTimeMillis()); device.setCapabilities(accountAttributes.getCapabilities()); - accounts.update(account.get(), a -> { - device.setId(a.getNextDeviceId()); - messages.clear(a.getUuid(), device.getId()); - a.addDevice(device); - }); + accounts.update(account.get(), a -> { + device.setId(a.getNextDeviceId()); + messages.clear(a.getUuid(), device.getId()); + a.addDevice(device); + }); pendingDevices.remove(number); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 5ce87a0ae..4c0847ed4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -72,6 +73,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.websocket.WebSocketResourceProvider; @@ -104,14 +106,18 @@ class AuthEnablementRefreshRequirementProviderTest { .addResource(new TestResource()) .build(); + private AccountsManager accountsManager; private ClientPresenceManager clientPresenceManager; private AuthEnablementRefreshRequirementProvider provider; @BeforeEach void setup() { + accountsManager = mock(AccountsManager.class); clientPresenceManager = mock(ClientPresenceManager.class); - provider = new AuthEnablementRefreshRequirementProvider(); + + provider = new AuthEnablementRefreshRequirementProvider(accountsManager); + final WebsocketRefreshRequestEventListener listener = new WebsocketRefreshRequestEventListener(clientPresenceManager, provider); @@ -122,6 +128,8 @@ class AuthEnablementRefreshRequirementProviderTest { account.addDevice(authenticatedDevice); LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId))); + when(accountsManager.get(uuid)).thenReturn(Optional.of(account)); + account.getDevices() .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true)); } @@ -301,6 +309,8 @@ class AuthEnablementRefreshRequirementProviderTest { .get(); assertEquals(200, response.getStatus()); + + verify(accountsManager, never()).get(any(UUID.class)); } @Nested @@ -402,6 +412,7 @@ class AuthEnablementRefreshRequirementProviderTest { @PUT @Path("/account/enabled/{enabled}") + @ChangesDeviceEnabledState public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabled") final boolean enabled) { final Device device = principal.getAccount().getMasterDevice().orElseThrow(); @@ -415,6 +426,7 @@ class AuthEnablementRefreshRequirementProviderTest { @POST @Path("/account/devices/enabled") + @ChangesDeviceEnabledState public String setEnabled(@Auth TestPrincipal principal, Map deviceIdsEnabled) { final StringBuilder response = new StringBuilder(); @@ -431,6 +443,7 @@ class AuthEnablementRefreshRequirementProviderTest { @PUT @Path("/account/devices") + @ChangesDeviceEnabledState public String addDevices(@Auth TestPrincipal auth, List deviceNames) { deviceNames.forEach(name -> { @@ -445,6 +458,7 @@ class AuthEnablementRefreshRequirementProviderTest { @DELETE @Path("/account/devices/{deviceIds}") + @ChangesDeviceEnabledState public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) { Arrays.stream(deviceIds.split(",")) @@ -456,6 +470,7 @@ class AuthEnablementRefreshRequirementProviderTest { @POST @Path("/account/disableMasterDeviceAndDeleteDevice/{deviceId}") + @ChangesDeviceEnabledState public String disableMasterDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) { DevicesHelper.setEnabled(auth.getAccount().getMasterDevice().orElseThrow(), false); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java index 300c45dd0..57c0d8dab 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.auth; import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.monitoring.RequestEvent; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.storage.Account; @@ -33,6 +34,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest { private PhoneNumberChangeRefreshRequirementProvider provider; private Account account; + private RequestEvent requestEvent; private ContainerRequest request; private static final UUID ACCOUNT_UUID = UUID.randomUUID(); @@ -62,23 +64,26 @@ class PhoneNumberChangeRefreshRequirementProviderTest { when(request.getProperty(anyString())).thenAnswer( invocation -> requestProperties.get(invocation.getArgument(0, String.class))); + + requestEvent = mock(RequestEvent.class); + when(requestEvent.getContainerRequest()).thenReturn(request); } @Test void handleRequestNoChange() { setAuthenticatedAccount(request, account); - provider.handleRequestFiltered(request); - assertEquals(Collections.emptyList(), provider.handleRequestFinished(request)); + provider.handleRequestFiltered(requestEvent); + assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); } @Test void handleRequestNumberChange() { setAuthenticatedAccount(request, account); - provider.handleRequestFiltered(request); + provider.handleRequestFiltered(requestEvent); when(account.getNumber()).thenReturn(CHANGED_NUMBER); - assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.MASTER_ID)), provider.handleRequestFinished(request)); + assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.MASTER_ID)), provider.handleRequestFinished(requestEvent)); } @Test @@ -86,11 +91,13 @@ class PhoneNumberChangeRefreshRequirementProviderTest { final ContainerRequest request = mock(ContainerRequest.class); setAuthenticatedAccount(request, null); - provider.handleRequestFiltered(request); - assertEquals(Collections.emptyList(), provider.handleRequestFinished(request)); + when(requestEvent.getContainerRequest()).thenReturn(request); + + provider.handleRequestFiltered(requestEvent); + assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); } - private void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) { + private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) { final SecurityContext securityContext = mock(SecurityContext.class); when(mockRequest.getSecurityContext()).thenReturn(securityContext); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index dfb380d03..1008d1e17 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -23,6 +24,7 @@ import io.dropwizard.testing.junit5.ResourceExtension; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Stream; import javax.ws.rs.Path; import javax.ws.rs.client.Entity; @@ -39,12 +41,14 @@ import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; +import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceResponse; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -85,6 +89,7 @@ class DeviceControllerTest { private static Account account = mock(Account.class ); private static Account maxedAccount = mock(Account.class); private static Device masterDevice = mock(Device.class); + private static ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); private static Map deviceConfiguration = new HashMap<>(); @@ -93,6 +98,7 @@ class DeviceControllerTest { .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)) .addProvider(new DeviceLimitExceededExceptionMapper()) .addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, @@ -143,12 +149,19 @@ class DeviceControllerTest { rateLimiter, account, maxedAccount, - masterDevice + masterDevice, + clientPresenceManager ); } @Test void validDeviceRegisterTest() { + when(accountsManager.get(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(existingDevice)); + VerificationCode deviceCode = resources.getJerseyTest() .target("/v1/devices/provisioning/code") .request() @@ -170,6 +183,7 @@ class DeviceControllerTest { verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); + verify(clientPresenceManager).displacePresence(AuthHelper.VALID_UUID, Device.MASTER_ID); } @Test