Add a "refresh websocket on number change" provider

This commit is contained in:
Jon Chambers 2021-09-10 15:32:28 -04:00 committed by Jon Chambers
parent 49ccbba2e3
commit 6a5d475198
8 changed files with 185 additions and 23 deletions

View File

@ -12,11 +12,9 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.ws.rs.core.SecurityContext;
import org.glassfish.jersey.server.ContainerRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -44,17 +42,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled";
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
private Optional<Account> findAccount(final ContainerRequest containerRequest) {
return Optional.ofNullable(containerRequest.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder) {
return ((AccountAndAuthenticatedDeviceHolder) principal).getAccount();
}
return null;
});
}
@VisibleForTesting
Map<Long, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream()
@ -63,11 +50,11 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
}
@Override
public void handleRequestStart(final ContainerRequest request) {
public void handleRequestFiltered(final ContainerRequest request) {
// The authenticated principal, if any, will be available after filters have run.
// Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
// before carrying out the requests business logic.
findAccount(request)
ContainerRequestUtil.getAuthenticatedAccount(request)
.ifPresent(
account -> {
request.setProperty(ACCOUNT_ENABLED, account.isEnabled());
@ -87,7 +74,7 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled =
(Map<Long, Boolean>) request.getProperty(DEVICES_ENABLED);
return findAccount(request).map(account -> {
return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> {
final Set<Long> deviceIdsToDisplace;
if (account.isEnabled() != accountInitiallyEnabled) {

View File

@ -0,0 +1,21 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.core.SecurityContext;
import java.util.Optional;
class ContainerRequestUtil {
static Optional<Account> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder
? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null);
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
@Override
public void handleRequestFiltered(final ContainerRequest request) {
ContainerRequestUtil.getAuthenticatedAccount(request)
.ifPresent(account -> request.setProperty(INITIAL_NUMBER_KEY, account.getNumber()));
}
@Override
public List<Pair<UUID, Long>> handleRequestFinished(final ContainerRequest request) {
final String initialNumber = (String) request.getProperty(INITIAL_NUMBER_KEY);
if (initialNumber != null) {
final Optional<Account> maybeAuthenticatedAccount = ContainerRequestUtil.getAuthenticatedAccount(request);
return maybeAuthenticatedAccount
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
} else {
return Collections.emptyList();
}
}
}

View File

@ -20,7 +20,8 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
public WebsocketRefreshApplicationEventListener(final ClientPresenceManager clientPresenceManager) {
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
new AuthEnablementRefreshRequirementProvider());
new AuthEnablementRefreshRequirementProvider(),
new PhoneNumberChangeRefreshRequirementProvider());
}
@Override

View File

@ -43,7 +43,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
public void onEvent(final RequestEvent event) {
if (event.getType() == Type.REQUEST_FILTERED) {
for (final WebsocketRefreshRequirementProvider provider : providers) {
provider.handleRequestStart(event.getContainerRequest());
provider.handleRequestFiltered(event.getContainerRequest());
}
} else if (event.getType() == Type.FINISHED) {
final AtomicInteger displacedDevices = new AtomicInteger(0);

View File

@ -21,7 +21,7 @@ public interface WebsocketRefreshRequirementProvider {
*
* @param request the request to observe
*/
void handleRequestStart(ContainerRequest request);
void handleRequestFiltered(ContainerRequest request);
/**
* Processes a request after all normal request handling has been completed.

View File

@ -89,10 +89,10 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private Account account = new Account();
private final Account account = new Account();
private Device authenticatedDevice = DevicesHelper.createDevice(1L);
private Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
private final ResourceExtension resources = ResourceExtension.builder()
@ -109,14 +109,15 @@ class AuthEnablementRefreshRequirementProviderTest {
private ClientPresenceManager clientPresenceManager;
private WebsocketRefreshRequestEventListener listener;
private AuthEnablementRefreshRequirementProvider provider;
@BeforeEach
void setup() {
clientPresenceManager = mock(ClientPresenceManager.class);
provider = new AuthEnablementRefreshRequirementProvider();
listener = new WebsocketRefreshRequestEventListener(clientPresenceManager, provider);
final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(clientPresenceManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
final UUID uuid = UUID.randomUUID();

View File

@ -0,0 +1,107 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.ContainerRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import javax.annotation.Nullable;
import javax.ws.rs.core.SecurityContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class PhoneNumberChangeRefreshRequirementProviderTest {
private PhoneNumberChangeRefreshRequirementProvider provider;
private Account account;
private ContainerRequest request;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
@BeforeEach
void setUp() {
provider = new PhoneNumberChangeRefreshRequirementProvider();
account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(NUMBER);
when(account.getDevices()).thenReturn(Set.of(device));
when(device.getId()).thenReturn(Device.MASTER_ID);
request = mock(ContainerRequest.class);
final Map<String, Object> requestProperties = new HashMap<>();
doAnswer(invocation -> {
requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1));
return null;
}).when(request).setProperty(anyString(), any());
when(request.getProperty(anyString())).thenAnswer(
invocation -> requestProperties.get(invocation.getArgument(0, String.class)));
}
@Test
void handleRequestNoChange() {
setAuthenticatedAccount(request, account);
provider.handleRequestFiltered(request);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(request));
}
@Test
void handleRequestNumberChange() {
setAuthenticatedAccount(request, account);
provider.handleRequestFiltered(request);
when(account.getNumber()).thenReturn(CHANGED_NUMBER);
assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.MASTER_ID)), provider.handleRequestFinished(request));
}
@Test
void handleRequestNoAuthenticatedAccount() {
final ContainerRequest request = mock(ContainerRequest.class);
setAuthenticatedAccount(request, null);
provider.handleRequestFiltered(request);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(request));
}
private void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) {
final SecurityContext securityContext = mock(SecurityContext.class);
when(mockRequest.getSecurityContext()).thenReturn(securityContext);
if (account != null) {
final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class);
when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount);
when(authenticatedAccount.getAccount()).thenReturn(account);
} else {
when(securityContext.getUserPrincipal()).thenReturn(null);
}
}
}