Fix flaky websocketTimeoutNoHeader test

This commit is contained in:
Ravi Khadiwala 2024-12-18 19:07:50 -06:00
parent 981a04f33b
commit a3e106fe04
1 changed files with 11 additions and 2 deletions

View File

@ -13,6 +13,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration; import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
@ -25,6 +26,7 @@ import java.net.URI;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
@ -37,12 +39,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessage;
@ -71,6 +74,7 @@ public class ProvisioningTimeoutIntegrationTest {
} }
public static class TestProvisioningListener extends TestWebsocketListener { public static class TestProvisioningListener extends TestWebsocketListener {
CompletableFuture<String> provisioningAddressFuture = new CompletableFuture<>();
@Override @Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
@ -78,11 +82,15 @@ public class ProvisioningTimeoutIntegrationTest {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) { && webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
// ignore, this is the provisioning address the server sends on connect MessageProtos.ProvisioningAddress provisioningAddress =
MessageProtos.ProvisioningAddress.parseFrom(webSocketMessage.getRequestMessage().getBody().orElseThrow());
provisioningAddressFuture.complete(provisioningAddress.getAddress());
return; return;
} }
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
} }
super.onWebSocketBinary(payload, offset, length); super.onWebSocketBinary(payload, offset, length);
} }
@ -130,6 +138,7 @@ public class ProvisioningTimeoutIntegrationTest {
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())), URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest).join()) { upgradeRequest).join()) {
assertThat(testWebsocketListener.provisioningAddressFuture.join()).isNotNull();
assertThat(testWebsocketListener.closeFuture()).isNotDone(); assertThat(testWebsocketListener.closeFuture()).isNotDone();
final ArgumentCaptor<Runnable> closeFunctionCaptor = ArgumentCaptor.forClass(Runnable.class); final ArgumentCaptor<Runnable> closeFunctionCaptor = ArgumentCaptor.forClass(Runnable.class);