Add opt-in timeouts to provisioning websocket
This commit is contained in:
parent
6460327372
commit
68f27be7cd
|
@ -569,6 +569,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
.scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build();
|
.scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build();
|
||||||
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
|
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
|
||||||
.scheduledExecutorService(name(getClass(), "messagePollExecutor-%d")).threads(1).build();
|
.scheduledExecutorService(name(getClass(), "messagePollExecutor-%d")).threads(1).build();
|
||||||
|
ScheduledExecutorService provisioningWebsocketTimeoutExecutor = environment.lifecycle()
|
||||||
|
.scheduledExecutorService(name(getClass(), "provisioningWebsocketTimeout-%d")).threads(1).build();
|
||||||
|
|
||||||
final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
|
final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
|
||||||
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
|
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
|
||||||
|
@ -1171,7 +1173,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
|
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
|
||||||
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
|
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
|
||||||
disconnectionRequestManager));
|
disconnectionRequestManager));
|
||||||
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
|
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90)));
|
||||||
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
|
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
|
||||||
provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager));
|
provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager));
|
||||||
provisioningEnvironment.jersey().register(new TimestampResponseFilter());
|
provisioningEnvironment.jersey().register(new TimestampResponseFilter());
|
||||||
|
|
|
@ -130,7 +130,7 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
||||||
import org.whispersystems.websocket.Stories;
|
import org.whispersystems.websocket.WebsocketHeaders;
|
||||||
import org.whispersystems.websocket.auth.ReadOnly;
|
import org.whispersystems.websocket.auth.ReadOnly;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
@ -749,10 +749,10 @@ public class MessageController {
|
||||||
@GET
|
@GET
|
||||||
@Produces(MediaType.APPLICATION_JSON)
|
@Produces(MediaType.APPLICATION_JSON)
|
||||||
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@ReadOnly @Auth AuthenticatedDevice auth,
|
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@ReadOnly @Auth AuthenticatedDevice auth,
|
||||||
@HeaderParam(Stories.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader,
|
@HeaderParam(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader,
|
||||||
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
|
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
|
||||||
|
|
||||||
boolean shouldReceiveStories = Stories.parseReceiveStoriesHeader(receiveStoriesHeader);
|
boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader);
|
||||||
|
|
||||||
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
|
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,15 @@ package org.whispersystems.textsecuregcm.websocket;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import java.security.SecureRandom;
|
import java.security.SecureRandom;
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.concurrent.ScheduledFuture;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
|
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
|
||||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
|
@ -31,7 +36,7 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
||||||
* a random, temporary "provisioning address," which it transmits via the newly-opened WebSocket. From there, the new
|
* a random, temporary "provisioning address," which it transmits via the newly-opened WebSocket. From there, the new
|
||||||
* device generally displays the provisioning address (and a public key) as a QR code. After that, the primary device
|
* device generally displays the provisioning address (and a public key) as a QR code. After that, the primary device
|
||||||
* will scan the QR code and send an encrypted provisioning message to the new device via
|
* will scan the QR code and send an encrypted provisioning message to the new device via
|
||||||
* {@link org.whispersystems.textsecuregcm.controllers.ProvisioningController#sendProvisioningMessage(AuthenticatedDevice, String, ProvisioningMessage, String)}.
|
* {@link ProvisioningController#sendProvisioningMessage(AuthenticatedDevice, String, ProvisioningMessage, String)}.
|
||||||
* Once the server receives the message from the primary device, it sends the message to the new device via the open
|
* Once the server receives the message from the primary device, it sends the message to the new device via the open
|
||||||
* WebSocket, then closes the WebSocket connection.
|
* WebSocket, then closes the WebSocket connection.
|
||||||
*/
|
*/
|
||||||
|
@ -39,9 +44,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
|
||||||
|
|
||||||
private final ProvisioningManager provisioningManager;
|
private final ProvisioningManager provisioningManager;
|
||||||
private final OpenWebSocketCounter openWebSocketCounter;
|
private final OpenWebSocketCounter openWebSocketCounter;
|
||||||
|
private final ScheduledExecutorService timeoutExecutor;
|
||||||
|
private final Duration timeout;
|
||||||
|
|
||||||
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
|
public ProvisioningConnectListener(final ProvisioningManager provisioningManager,
|
||||||
|
final ScheduledExecutorService timeoutExecutor,
|
||||||
|
final Duration timeout) {
|
||||||
this.provisioningManager = provisioningManager;
|
this.provisioningManager = provisioningManager;
|
||||||
|
this.timeoutExecutor = timeoutExecutor;
|
||||||
|
this.timeout = timeout;
|
||||||
this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets"),
|
this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets"),
|
||||||
MetricsUtil.name(getClass(), "sessionDuration"));
|
MetricsUtil.name(getClass(), "sessionDuration"));
|
||||||
}
|
}
|
||||||
|
@ -50,8 +61,17 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
|
||||||
public void onWebSocketConnect(WebSocketSessionContext context) {
|
public void onWebSocketConnect(WebSocketSessionContext context) {
|
||||||
openWebSocketCounter.countOpenWebSocket(context);
|
openWebSocketCounter.countOpenWebSocket(context);
|
||||||
|
|
||||||
|
final Optional<ScheduledFuture<?>> maybeTimeoutFuture = context.getClient().supportsProvisioningSocketTimeouts()
|
||||||
|
? Optional.of(timeoutExecutor.schedule(() ->
|
||||||
|
context.getClient().close(1000, "Timeout"), timeout.toSeconds(), TimeUnit.SECONDS))
|
||||||
|
: Optional.empty();
|
||||||
|
|
||||||
final String provisioningAddress = generateProvisioningAddress();
|
final String provisioningAddress = generateProvisioningAddress();
|
||||||
context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
|
|
||||||
|
context.addWebsocketClosedListener((context1, statusCode, reason) -> {
|
||||||
|
provisioningManager.removeListener(provisioningAddress);
|
||||||
|
maybeTimeoutFuture.ifPresent(future -> future.cancel(false));
|
||||||
|
});
|
||||||
|
|
||||||
provisioningManager.addListener(provisioningAddress, message -> {
|
provisioningManager.addListener(provisioningAddress, message -> {
|
||||||
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
|
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
|
||||||
|
|
|
@ -0,0 +1,180 @@
|
||||||
|
package org.whispersystems.textsecuregcm;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
|
import static org.mockito.Mockito.doReturn;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.reset;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
|
||||||
|
|
||||||
|
import io.dropwizard.core.Application;
|
||||||
|
import io.dropwizard.core.Configuration;
|
||||||
|
import io.dropwizard.core.setup.Environment;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
|
import jakarta.servlet.DispatcherType;
|
||||||
|
import jakarta.servlet.ServletRegistration;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.concurrent.ScheduledFuture;
|
||||||
|
import org.eclipse.jetty.websocket.api.Session;
|
||||||
|
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||||
|
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||||
|
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||||
|
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
|
||||||
|
import org.whispersystems.websocket.WebsocketHeaders;
|
||||||
|
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||||
|
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||||
|
import org.whispersystems.websocket.messages.InvalidMessageException;
|
||||||
|
import org.whispersystems.websocket.messages.WebSocketMessage;
|
||||||
|
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||||
|
|
||||||
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
|
public class ProvisioningTimeoutIntegrationTest {
|
||||||
|
|
||||||
|
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
|
||||||
|
new DropwizardAppExtension<>(TestApplication.class);
|
||||||
|
|
||||||
|
|
||||||
|
private WebSocketClient client;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
|
client = new WebSocketClient();
|
||||||
|
client.start();
|
||||||
|
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
|
||||||
|
reset(testApplication.scheduler);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
client.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestProvisioningListener extends TestWebsocketListener {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
|
||||||
|
try {
|
||||||
|
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
|
||||||
|
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
|
||||||
|
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
|
||||||
|
// ignore, this is the provisioning address the server sends on connect
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} catch (InvalidMessageException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
super.onWebSocketBinary(payload, offset, length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestApplication extends Application<Configuration> {
|
||||||
|
|
||||||
|
ScheduledExecutorService scheduler = mock(ScheduledExecutorService.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(final Configuration configuration, final Environment environment) throws Exception {
|
||||||
|
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
|
||||||
|
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
|
||||||
|
new WebSocketEnvironment<>(environment, webSocketConfiguration);
|
||||||
|
|
||||||
|
environment.servlets()
|
||||||
|
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
|
||||||
|
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||||
|
webSocketEnvironment.setConnectListener(
|
||||||
|
new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5)));
|
||||||
|
|
||||||
|
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
|
||||||
|
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
|
||||||
|
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
|
|
||||||
|
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||||
|
final ServletRegistration.Dynamic websocketServlet = environment.servlets()
|
||||||
|
.addServlet("WebSocket", webSocketServlet);
|
||||||
|
websocketServlet.addMapping("/websocket");
|
||||||
|
websocketServlet.setAsyncSupported(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void websocketTimeoutWithHeader() throws IOException {
|
||||||
|
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
|
||||||
|
|
||||||
|
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
|
||||||
|
when(testApplication.scheduler.schedule(any(Runnable.class), anyLong(), any()))
|
||||||
|
.thenReturn(mock(ScheduledFuture.class));
|
||||||
|
|
||||||
|
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||||
|
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
|
||||||
|
try (Session ignored = client.connect(testWebsocketListener,
|
||||||
|
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
|
||||||
|
upgradeRequest).join()) {
|
||||||
|
|
||||||
|
assertThat(testWebsocketListener.closeFuture()).isNotDone();
|
||||||
|
|
||||||
|
final ArgumentCaptor<Runnable> closeFunctionCaptor = ArgumentCaptor.forClass(Runnable.class);
|
||||||
|
verify(testApplication.scheduler).schedule(closeFunctionCaptor.capture(), anyLong(), any());
|
||||||
|
closeFunctionCaptor.getValue().run();
|
||||||
|
|
||||||
|
assertThat(testWebsocketListener.closeFuture())
|
||||||
|
.succeedsWithin(Duration.ofSeconds(1))
|
||||||
|
.isEqualTo(1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void websocketTimeoutNoHeader() throws IOException {
|
||||||
|
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
|
||||||
|
|
||||||
|
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||||
|
try (Session ignored = client.connect(testWebsocketListener,
|
||||||
|
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
|
||||||
|
upgradeRequest).join()) {
|
||||||
|
assertThat(testWebsocketListener.closeFuture()).isNotDone();
|
||||||
|
|
||||||
|
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
|
||||||
|
verify(testApplication.scheduler, never()).schedule(any(Runnable.class), anyLong(), any());
|
||||||
|
assertThat(testWebsocketListener.closeFuture()).isNotDone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void websocketTimeoutCancelled() throws IOException {
|
||||||
|
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
|
||||||
|
|
||||||
|
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
|
||||||
|
@SuppressWarnings("unchecked") final ScheduledFuture<Void> scheduled = mock(ScheduledFuture.class);
|
||||||
|
doReturn(scheduled).when(testApplication.scheduler).schedule(any(Runnable.class), anyLong(), any());
|
||||||
|
|
||||||
|
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||||
|
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
|
||||||
|
final Session session = client.connect(testWebsocketListener,
|
||||||
|
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
|
||||||
|
upgradeRequest).join();
|
||||||
|
|
||||||
|
// Close the websocket, make sure the timeout is cancelled.
|
||||||
|
session.close();
|
||||||
|
assertThat(testWebsocketListener.closeFuture()).succeedsWithin(Duration.ofSeconds(1));
|
||||||
|
verify(scheduled, times(1)).cancel(anyBoolean());
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,7 +19,6 @@ import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.argThat;
|
import static org.mockito.ArgumentMatchers.argThat;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.anyBoolean;
|
import static org.mockito.Mockito.anyBoolean;
|
||||||
import static org.mockito.Mockito.anyString;
|
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
import static org.mockito.Mockito.reset;
|
import static org.mockito.Mockito.reset;
|
||||||
|
@ -132,7 +131,7 @@ import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
import org.whispersystems.websocket.Stories;
|
import org.whispersystems.websocket.WebsocketHeaders;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.core.scheduler.Scheduler;
|
import reactor.core.scheduler.Scheduler;
|
||||||
import reactor.core.scheduler.Schedulers;
|
import reactor.core.scheduler.Schedulers;
|
||||||
|
@ -675,7 +674,7 @@ class MessageControllerTest {
|
||||||
resources.getJerseyTest().target("/v1/messages/")
|
resources.getJerseyTest().target("/v1/messages/")
|
||||||
.request()
|
.request()
|
||||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||||
.header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
|
.header(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
|
||||||
.header(HttpHeaders.USER_AGENT, userAgent)
|
.header(HttpHeaders.USER_AGENT, userAgent)
|
||||||
.accept(MediaType.APPLICATION_JSON_TYPE)
|
.accept(MediaType.APPLICATION_JSON_TYPE)
|
||||||
.get(OutgoingMessageEntityList.class);
|
.get(OutgoingMessageEntityList.class);
|
||||||
|
|
|
@ -24,8 +24,9 @@ public class TestWebsocketListener implements WebSocketListener {
|
||||||
|
|
||||||
private final AtomicLong requestId = new AtomicLong();
|
private final AtomicLong requestId = new AtomicLong();
|
||||||
private final CompletableFuture<Session> started = new CompletableFuture<>();
|
private final CompletableFuture<Session> started = new CompletableFuture<>();
|
||||||
|
private final CompletableFuture<Integer> closed = new CompletableFuture<>();
|
||||||
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
|
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
|
||||||
private final WebSocketMessageFactory messageFactory;
|
protected final WebSocketMessageFactory messageFactory;
|
||||||
|
|
||||||
public TestWebsocketListener() {
|
public TestWebsocketListener() {
|
||||||
this.messageFactory = new ProtobufWebSocketMessageFactory();
|
this.messageFactory = new ProtobufWebSocketMessageFactory();
|
||||||
|
@ -38,6 +39,15 @@ public class TestWebsocketListener implements WebSocketListener {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onWebSocketClose(int statusCode, String reason) {
|
||||||
|
closed.complete(statusCode);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompletableFuture<Integer> closeFuture() {
|
||||||
|
return closed;
|
||||||
|
}
|
||||||
|
|
||||||
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
|
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
|
||||||
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
|
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,24 @@
|
||||||
package org.whispersystems.textsecuregcm.websocket;
|
package org.whispersystems.textsecuregcm.websocket;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyInt;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.doReturn;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.concurrent.ScheduledFuture;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.mockito.ArgumentCaptor;
|
import org.mockito.ArgumentCaptor;
|
||||||
|
@ -9,23 +27,20 @@ import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||||
import org.whispersystems.websocket.WebSocketClient;
|
import org.whispersystems.websocket.WebSocketClient;
|
||||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||||
|
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
|
|
||||||
class ProvisioningConnectListenerTest {
|
class ProvisioningConnectListenerTest {
|
||||||
|
|
||||||
private ProvisioningManager provisioningManager;
|
private ProvisioningManager provisioningManager;
|
||||||
private ProvisioningConnectListener provisioningConnectListener;
|
private ProvisioningConnectListener provisioningConnectListener;
|
||||||
|
private ScheduledExecutorService scheduledExecutorService;
|
||||||
|
|
||||||
|
private static Duration TIMEOUT = Duration.ofSeconds(5);
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setUp() {
|
void setUp() {
|
||||||
provisioningManager = mock(ProvisioningManager.class);
|
provisioningManager = mock(ProvisioningManager.class);
|
||||||
provisioningConnectListener = new ProvisioningConnectListener(provisioningManager);
|
scheduledExecutorService = mock(ScheduledExecutorService.class);
|
||||||
|
provisioningConnectListener =
|
||||||
|
new ProvisioningConnectListener(provisioningManager, scheduledExecutorService, TIMEOUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -60,4 +75,49 @@ class ProvisioningConnectListenerTest {
|
||||||
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
|
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
|
||||||
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
|
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void schedulesTimeout() {
|
||||||
|
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
|
||||||
|
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
|
||||||
|
|
||||||
|
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
|
||||||
|
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||||
|
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
|
||||||
|
|
||||||
|
final ArgumentCaptor<Runnable> scheduleCaptor = ArgumentCaptor.forClass(Runnable.class);
|
||||||
|
provisioningConnectListener.onWebSocketConnect(context);
|
||||||
|
verify(scheduledExecutorService).schedule(scheduleCaptor.capture(), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
|
||||||
|
|
||||||
|
verify(webSocketClient, never()).close(anyInt(), any());
|
||||||
|
scheduleCaptor.getValue().run();
|
||||||
|
verify(webSocketClient, times(1)).close(eq(1000), anyString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void cancelsTimeout() {
|
||||||
|
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
|
||||||
|
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
|
||||||
|
|
||||||
|
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
|
||||||
|
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
|
||||||
|
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
|
||||||
|
|
||||||
|
provisioningConnectListener.onWebSocketConnect(context);
|
||||||
|
verify(scheduledExecutorService).schedule(any(Runnable.class), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
|
||||||
|
|
||||||
|
context.notifyClosed(1000, "Test");
|
||||||
|
|
||||||
|
verify(scheduledFuture).cancel(false);
|
||||||
|
verify(webSocketClient, never()).close(anyInt(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void skipsTimeoutIfUnsupported() {
|
||||||
|
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
|
||||||
|
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
|
||||||
|
provisioningConnectListener.onWebSocketConnect(context);
|
||||||
|
verify(scheduledExecutorService, never())
|
||||||
|
.schedule(any(Runnable.class), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
package org.whispersystems.websocket;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Class containing constants and shared logic for handling stories.
|
|
||||||
* <p>
|
|
||||||
* In particular, it defines the way we interpret the X-Signal-Receive-Stories header
|
|
||||||
* which is used by both WebSockets and by the REST API.
|
|
||||||
*/
|
|
||||||
public class Stories {
|
|
||||||
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
|
|
||||||
|
|
||||||
public static boolean parseReceiveStoriesHeader(String s) {
|
|
||||||
return "true".equals(s);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -98,8 +98,12 @@ public class WebSocketClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean shouldDeliverStories() {
|
public boolean shouldDeliverStories() {
|
||||||
String value = session.getUpgradeRequest().getHeader(Stories.X_SIGNAL_RECEIVE_STORIES);
|
String value = session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES);
|
||||||
return Stories.parseReceiveStoriesHeader(value);
|
return WebsocketHeaders.parseReceiveStoriesHeader(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean supportsProvisioningSocketTimeouts() {
|
||||||
|
return session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER) != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private long generateRequestId() {
|
private long generateRequestId() {
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
package org.whispersystems.websocket;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class containing constants and shared logic for headers used in websocket upgrade requests.
|
||||||
|
*/
|
||||||
|
public class WebsocketHeaders {
|
||||||
|
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
|
||||||
|
public static final String X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER = "X-Signal-Websocket-Timeout";
|
||||||
|
|
||||||
|
public static boolean parseReceiveStoriesHeader(String s) {
|
||||||
|
return "true".equals(s);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue