From 9ec4f0b2f584dff193b3a7723d39796ef2ec8688 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Fri, 24 May 2024 09:11:19 -0400 Subject: [PATCH] Gracefully handle proxy protocol messages at the beginning of TCP connections --- service/pom.xml | 5 + .../grpc/net/HAProxyMessageHandler.java | 32 ++++++ .../grpc/net/NoiseWebSocketTunnelServer.java | 4 + .../net/ProxyProtocolDetectionHandler.java | 73 ++++++++++++ .../net/EstablishRemoteConnectionHandler.java | 18 ++- .../grpc/net/HAProxyMessageHandlerTest.java | 62 ++++++++++ .../grpc/net/HAProxyMessageSender.java | 28 +++++ .../grpc/net/NoiseWebSocketTunnelClient.java | 6 +- ...eWebSocketTunnelServerIntegrationTest.java | 35 ++++-- .../ProxyProtocolDetectionHandlerTest.java | 108 ++++++++++++++++++ .../RejectUnsupportedMessagesHandlerTest.java | 1 - 11 files changed, 361 insertions(+), 11 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java diff --git a/service/pom.xml b/service/pom.xml index 0b78c84c0..656c26dc0 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -398,6 +398,11 @@ argparse4j + + io.netty + netty-codec-haproxy + + org.glassfish.jersey.test-framework jersey-test-framework-core diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java new file mode 100644 index 000000000..2bd807967 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java @@ -0,0 +1,32 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has + * either handled a proxy protocol message or determined that no such message is coming. + */ +public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter { + + private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class); + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof HAProxyMessage haProxyMessage) { + // Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still + // need to clear it from the pipeline to avoid confusing the TLS machinery, though. + log.debug("Discarding HAProxy message: {}", haProxyMessage); + haProxyMessage.release(); + } else { + super.channelRead(context, message); + } + + // Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first + // message, all others will just be "normal" messages, and our work here is done. + context.pipeline().remove(this); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java index deabb65be..ef0be249c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java @@ -89,6 +89,10 @@ public class NoiseWebSocketTunnelServer implements Managed { .childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel socketChannel) { + socketChannel.pipeline() + .addLast(new ProxyProtocolDetectionHandler()) + .addLast(new HAProxyMessageHandler()); + if (sslContext != null) { socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java new file mode 100644 index 000000000..3b8951658 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java @@ -0,0 +1,73 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessageDecoder; + +/** + * A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection. + * If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline. + * In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol + * message, it will remove itself from the pipeline and pass any intercepted down the pipeline. + * + * @see The PROXY protocol + */ +public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter { + + private CompositeByteBuf accumulator; + + @VisibleForTesting + static final int PROXY_MESSAGE_DETECTION_BYTES = 12; + + @Override + public void handlerAdded(final ChannelHandlerContext context) { + // We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get + // non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every + // practical case, though, we'll be able to tell from the first packet. + accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof ByteBuf byteBuf) { + accumulator.addComponent(true, byteBuf); + + switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) { + case NEEDS_MORE_DATA -> { + } + + case INVALID -> { + // We have enough information to determine that this connection is NOT starting with a proxy protocol message, + // and we can just pass the accumulated bytes through + context.fireChannelRead(accumulator); + + accumulator = null; + context.pipeline().remove(this); + } + + case DETECTED -> { + // We have enough information to know that we're dealing with a proxy protocol message; add appropriate + // handlers and pass the accumulated bytes through + context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder()); + context.fireChannelRead(accumulator); + + accumulator = null; + context.pipeline().remove(this); + } + } + } else { + super.channelRead(context, message); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + if (accumulator != null) { + accumulator.release(); + accumulator = null; + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java index a3f62843b..b6d1e05c7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java @@ -8,6 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; @@ -21,6 +23,7 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; import java.util.UUID; +import java.util.function.Supplier; import javax.annotation.Nullable; import javax.net.ssl.SSLException; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -39,6 +42,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { private final HttpHeaders headers; private final SocketAddress remoteServerAddress; private final WebSocketCloseListener webSocketCloseListener; + @Nullable private final Supplier proxyMessageSupplier; private final List pendingReads = new ArrayList<>(); @@ -55,7 +59,8 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { final byte deviceId, final HttpHeaders headers, final SocketAddress remoteServerAddress, - final WebSocketCloseListener webSocketCloseListener) { + final WebSocketCloseListener webSocketCloseListener, + @Nullable Supplier proxyMessageSupplier) { this.useTls = useTls; this.trustedServerCertificate = trustedServerCertificate; @@ -68,6 +73,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { this.headers = headers; this.remoteServerAddress = remoteServerAddress; this.webSocketCloseListener = webSocketCloseListener; + this.proxyMessageSupplier = proxyMessageSupplier; } @Override @@ -78,6 +84,16 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { .handler(new ChannelInitializer() { @Override protected void initChannel(final SocketChannel channel) throws SSLException { + + if (proxyMessageSupplier != null) { + // In a production setting, we'd want some mechanism to remove these handlers after the initial message + // were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a + // pair of inert handlers in the pipeline. + channel.pipeline() + .addLast(HAProxyMessageEncoder.INSTANCE) + .addLast(new HAProxyMessageSender(proxyMessageSupplier)); + } + if (useTls) { final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java new file mode 100644 index 000000000..de26f3c82 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java @@ -0,0 +1,62 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class HAProxyMessageHandlerTest { + + private EmbeddedChannel embeddedChannel; + + @BeforeEach + void setUp() { + embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler()); + } + + @Test + void handleHAProxyMessage() throws InterruptedException { + final HAProxyMessage haProxyMessage = new HAProxyMessage( + HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, + "10.0.0.1", "10.0.0.2", 12345, 443); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage); + embeddedChannel.flushInbound(); + + writeFuture.await(); + + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + assertEquals(0, haProxyMessage.refCnt()); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + } + + @Test + void handleNonHAProxyMessage() throws InterruptedException { + final byte[] bytes = new byte[32]; + ThreadLocalRandom.current().nextBytes(bytes); + + final ByteBuf message = Unpooled.wrappedBuffer(bytes); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message); + embeddedChannel.flushInbound(); + + writeFuture.await(); + + assertEquals(1, embeddedChannel.inboundMessages().size()); + assertEquals(message, embeddedChannel.inboundMessages().poll()); + assertEquals(1, message.refCnt()); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java new file mode 100644 index 000000000..4bd337009 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java @@ -0,0 +1,28 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import java.util.function.Supplier; + +class HAProxyMessageSender extends ChannelInboundHandlerAdapter { + + private final Supplier messageSupplier; + + HAProxyMessageSender(final Supplier messageSupplier) { + this.messageSupplier = messageSupplier; + } + + @Override + public void handlerAdded(final ChannelHandlerContext context) { + if (context.channel().isActive()) { + context.writeAndFlush(messageSupplier.get()); + } + } + + @Override + public void channelActive(final ChannelHandlerContext context) { + context.writeAndFlush(messageSupplier.get()); + context.fireChannelActive(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java index f14e7cd10..44c2974ef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java @@ -11,6 +11,8 @@ import java.net.SocketAddress; import java.net.URI; import java.security.cert.X509Certificate; import java.util.UUID; +import java.util.function.Supplier; +import io.netty.handler.codec.haproxy.HAProxyMessage; import io.netty.handler.codec.http.HttpHeaders; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; @@ -34,6 +36,7 @@ class NoiseWebSocketTunnelClient implements AutoCloseable { final HttpHeaders headers, final boolean useTls, @Nullable final X509Certificate trustedServerCertificate, + @Nullable final Supplier proxyMessageSupplier, final NioEventLoopGroup eventLoopGroup, final WebSocketCloseListener webSocketCloseListener) { @@ -54,7 +57,8 @@ class NoiseWebSocketTunnelClient implements AutoCloseable { deviceId, headers, remoteServerAddress, - webSocketCloseListener)); + webSocketCloseListener, + proxyMessageSupplier)); } }); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java index ee8269426..844004c1e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java @@ -16,6 +16,10 @@ import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; import java.io.ByteArrayInputStream; @@ -54,6 +58,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetRequestAttributesRequest; @@ -234,9 +240,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); } - @Test - void connectAuthenticated() throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient()) { + @ParameterizedTest + @ValueSource(booleans = { true, false }) + void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException { + try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -251,8 +258,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - @Test - void connectAuthenticatedPlaintext() throws InterruptedException { + @ParameterizedTest + @ValueSource(booleans = { true, false }) + void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException { try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient( tlsNoiseWebSocketTunnelServer.getLocalAddress(), NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI, @@ -264,6 +272,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes new DefaultHttpHeaders(), true, serverTlsCertificate, + includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null, nioEventLoopGroup, WebSocketCloseListener.NOOP_LISTENER) .start()) { @@ -289,7 +298,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes // Try to verify the server's public key with something other than the key with which it was signed try (final NoiseWebSocketTunnelClient client = - buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) { + buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); @@ -369,6 +378,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes new DefaultHttpHeaders(), true, serverTlsCertificate, + null, nioEventLoopGroup, webSocketCloseListener) .start()) { @@ -443,6 +453,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes new DefaultHttpHeaders(), true, serverTlsCertificate, + null, nioEventLoopGroup, webSocketCloseListener) .start()) { @@ -602,12 +613,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener) throws InterruptedException { - return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders()); + return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false); } private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener, final ECPublicKey rootPublicKey, - final HttpHeaders headers) throws InterruptedException { + final HttpHeaders headers, + final boolean includeProxyMessage) throws InterruptedException { return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(), NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI, @@ -619,6 +631,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes headers, true, serverTlsCertificate, + includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null, nioEventLoopGroup, webSocketCloseListener) .start(); @@ -642,8 +655,14 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes headers, true, serverTlsCertificate, + null, nioEventLoopGroup, webSocketCloseListener) .start(); } + + private static HAProxyMessage buildProxyMessage() { + return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, + "10.0.0.1", "10.0.0.2", 12345, 443); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java new file mode 100644 index 000000000..d7d3751ac --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java @@ -0,0 +1,108 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import java.util.HexFormat; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class ProxyProtocolDetectionHandlerTest { + + private EmbeddedChannel embeddedChannel; + + private static final byte[] PROXY_V2_MESSAGE_BYTES = + HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb"); + + @BeforeEach + void setUp() { + embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler()); + } + + @Test + void singlePacketProxyMessage() throws InterruptedException { + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES)); + embeddedChannel.flushInbound(); + + writeFuture.await(); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + assertEquals(1, embeddedChannel.inboundMessages().size()); + assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll()); + } + + @Test + void multiPacketProxyMessage() throws InterruptedException { + final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound( + Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0, + ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)); + + final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound( + Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1, + PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1))); + + embeddedChannel.flushInbound(); + + firstWriteFuture.await(); + secondWriteFuture.await(); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + assertEquals(1, embeddedChannel.inboundMessages().size()); + assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll()); + } + + @Test + void singlePacketNonProxyMessage() throws InterruptedException { + final byte[] nonProxyProtocolMessage = new byte[32]; + ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage)); + embeddedChannel.flushInbound(); + + writeFuture.await(); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + assertEquals(1, embeddedChannel.inboundMessages().size()); + + final Object inboundMessage = embeddedChannel.inboundMessages().poll(); + + assertInstanceOf(ByteBuf.class, inboundMessage); + assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage)); + } + + @Test + void multiPacketNonProxyMessage() throws InterruptedException { + final byte[] nonProxyProtocolMessage = new byte[32]; + ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage); + + final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound( + Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0, + ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)); + + final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound( + Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1, + nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1))); + + embeddedChannel.flushInbound(); + + firstWriteFuture.await(); + secondWriteFuture.await(); + + assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); + assertEquals(1, embeddedChannel.inboundMessages().size()); + + final Object inboundMessage = embeddedChannel.inboundMessages().poll(); + + assertInstanceOf(ByteBuf.class, inboundMessage); + assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java index c16a6468d..da312c719 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java @@ -13,7 +13,6 @@ import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import io.netty.util.ReferenceCountUtil; import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test;