enforce provisioning websocket timeouts

This commit is contained in:
Ravi Khadiwala 2025-03-26 15:50:44 -05:00 committed by ravi-signal
parent 8c2f3c839f
commit c2e3ab832c
5 changed files with 5 additions and 39 deletions

View File

@ -61,16 +61,14 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
public void onWebSocketConnect(WebSocketSessionContext context) {
openWebSocketCounter.countOpenWebSocket(context);
final Optional<ScheduledFuture<?>> maybeTimeoutFuture = context.getClient().supportsProvisioningSocketTimeouts()
? Optional.of(timeoutExecutor.schedule(() ->
context.getClient().close(1000, "Timeout"), timeout.toSeconds(), TimeUnit.SECONDS))
: Optional.empty();
final ScheduledFuture<?> timeoutFuture = timeoutExecutor.schedule(() ->
context.getClient().close(1000, "Timeout"), timeout.toSeconds(), TimeUnit.SECONDS);
final String provisioningAddress = generateProvisioningAddress();
context.addWebsocketClosedListener((context1, statusCode, reason) -> {
provisioningManager.removeListener(provisioningAddress);
maybeTimeoutFuture.ifPresent(future -> future.cancel(false));
timeoutFuture.cancel(false);
});
provisioningManager.addListener(provisioningAddress, message -> {

View File

@ -133,7 +133,6 @@ public class ProvisioningTimeoutIntegrationTest {
.thenReturn(mock(ScheduledFuture.class));
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
try (Session ignored = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join()) {
@ -151,22 +150,6 @@ public class ProvisioningTimeoutIntegrationTest {
}
}
@Test
public void websocketTimeoutNoHeader() throws IOException {
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
try (Session ignored = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join()) {
assertThat(testWebsocketListener.closeFuture()).isNotDone();
final TestApplication testApplication = DROPWIZARD_APP_EXTENSION.getApplication();
verify(testApplication.scheduler, never()).schedule(any(Runnable.class), anyLong(), any());
assertThat(testWebsocketListener.closeFuture()).isNotDone();
}
}
@Test
public void websocketTimeoutCancelled() throws IOException {
final TestProvisioningListener testWebsocketListener = new TestProvisioningListener();
@ -176,7 +159,6 @@ public class ProvisioningTimeoutIntegrationTest {
doReturn(scheduled).when(testApplication.scheduler).schedule(any(Runnable.class), anyLong(), any());
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER, "");
final Session session = client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join();

View File

@ -47,6 +47,8 @@ class ProvisioningConnectListenerTest {
void onWebSocketConnect() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
provisioningConnectListener.onWebSocketConnect(context);
context.notifyClosed(1000, "Test");
@ -81,7 +83,6 @@ class ProvisioningConnectListenerTest {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
@ -99,7 +100,6 @@ class ProvisioningConnectListenerTest {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
when(webSocketClient.supportsProvisioningSocketTimeouts()).thenReturn(true);
final ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
doReturn(scheduledFuture).when(scheduledExecutorService).schedule(any(Runnable.class), anyLong(), any());
@ -111,13 +111,4 @@ class ProvisioningConnectListenerTest {
verify(scheduledFuture).cancel(false);
verify(webSocketClient, never()).close(anyInt(), any());
}
@Test
void skipsTimeoutIfUnsupported() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
provisioningConnectListener.onWebSocketConnect(context);
verify(scheduledExecutorService, never())
.schedule(any(Runnable.class), eq(TIMEOUT.getSeconds()), eq(TimeUnit.SECONDS));
}
}

View File

@ -102,10 +102,6 @@ public class WebSocketClient {
return WebsocketHeaders.parseReceiveStoriesHeader(value);
}
public boolean supportsProvisioningSocketTimeouts() {
return session.getUpgradeRequest().getHeader(WebsocketHeaders.X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER) != null;
}
private long generateRequestId() {
return Math.abs(SECURE_RANDOM.nextLong());
}

View File

@ -5,7 +5,6 @@ package org.whispersystems.websocket;
*/
public class WebsocketHeaders {
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
public static final String X_SIGNAL_WEBSOCKET_TIMEOUT_HEADER = "X-Signal-Websocket-Timeout";
public static boolean parseReceiveStoriesHeader(String s) {
return "true".equals(s);