Gracefully handle proxy protocol messages at the beginning of TCP connections
This commit is contained in:
parent
1678045ce4
commit
9ec4f0b2f5
|
@ -398,6 +398,11 @@
|
|||
<artifactId>argparse4j</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.netty</groupId>
|
||||
<artifactId>netty-codec-haproxy</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.test-framework</groupId>
|
||||
<artifactId>jersey-test-framework-core</artifactId>
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has
|
||||
* either handled a proxy protocol message or determined that no such message is coming.
|
||||
*/
|
||||
public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class);
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof HAProxyMessage haProxyMessage) {
|
||||
// Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still
|
||||
// need to clear it from the pipeline to avoid confusing the TLS machinery, though.
|
||||
log.debug("Discarding HAProxy message: {}", haProxyMessage);
|
||||
haProxyMessage.release();
|
||||
} else {
|
||||
super.channelRead(context, message);
|
||||
}
|
||||
|
||||
// Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first
|
||||
// message, all others will just be "normal" messages, and our work here is done.
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
}
|
|
@ -89,6 +89,10 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
|||
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel socketChannel) {
|
||||
socketChannel.pipeline()
|
||||
.addLast(new ProxyProtocolDetectionHandler())
|
||||
.addLast(new HAProxyMessageHandler());
|
||||
|
||||
if (sslContext != null) {
|
||||
socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
|
||||
|
||||
/**
|
||||
* A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection.
|
||||
* If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline.
|
||||
* In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol
|
||||
* message, it will remove itself from the pipeline and pass any intercepted down the pipeline.
|
||||
*
|
||||
* @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">The PROXY protocol</a>
|
||||
*/
|
||||
public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private CompositeByteBuf accumulator;
|
||||
|
||||
@VisibleForTesting
|
||||
static final int PROXY_MESSAGE_DETECTION_BYTES = 12;
|
||||
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext context) {
|
||||
// We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get
|
||||
// non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every
|
||||
// practical case, though, we'll be able to tell from the first packet.
|
||||
accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof ByteBuf byteBuf) {
|
||||
accumulator.addComponent(true, byteBuf);
|
||||
|
||||
switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) {
|
||||
case NEEDS_MORE_DATA -> {
|
||||
}
|
||||
|
||||
case INVALID -> {
|
||||
// We have enough information to determine that this connection is NOT starting with a proxy protocol message,
|
||||
// and we can just pass the accumulated bytes through
|
||||
context.fireChannelRead(accumulator);
|
||||
|
||||
accumulator = null;
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
|
||||
case DETECTED -> {
|
||||
// We have enough information to know that we're dealing with a proxy protocol message; add appropriate
|
||||
// handlers and pass the accumulated bytes through
|
||||
context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder());
|
||||
context.fireChannelRead(accumulator);
|
||||
|
||||
accumulator = null;
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
super.channelRead(context, message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
if (accumulator != null) {
|
||||
accumulator.release();
|
||||
accumulator = null;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,6 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
|||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
||||
import io.netty.handler.codec.http.HttpClientCodec;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
|
@ -21,6 +23,7 @@ import java.security.cert.X509Certificate;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Supplier;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.net.ssl.SSLException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
@ -39,6 +42,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
private final HttpHeaders headers;
|
||||
private final SocketAddress remoteServerAddress;
|
||||
private final WebSocketCloseListener webSocketCloseListener;
|
||||
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
|
@ -55,7 +59,8 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
final byte deviceId,
|
||||
final HttpHeaders headers,
|
||||
final SocketAddress remoteServerAddress,
|
||||
final WebSocketCloseListener webSocketCloseListener) {
|
||||
final WebSocketCloseListener webSocketCloseListener,
|
||||
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier) {
|
||||
|
||||
this.useTls = useTls;
|
||||
this.trustedServerCertificate = trustedServerCertificate;
|
||||
|
@ -68,6 +73,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
this.headers = headers;
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.webSocketCloseListener = webSocketCloseListener;
|
||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -78,6 +84,16 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
.handler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(final SocketChannel channel) throws SSLException {
|
||||
|
||||
if (proxyMessageSupplier != null) {
|
||||
// In a production setting, we'd want some mechanism to remove these handlers after the initial message
|
||||
// were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a
|
||||
// pair of inert handlers in the pipeline.
|
||||
channel.pipeline()
|
||||
.addLast(HAProxyMessageEncoder.INSTANCE)
|
||||
.addLast(new HAProxyMessageSender(proxyMessageSupplier));
|
||||
}
|
||||
|
||||
if (useTls) {
|
||||
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
|
||||
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class HAProxyMessageHandlerTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleHAProxyMessage() throws InterruptedException {
|
||||
final HAProxyMessage haProxyMessage = new HAProxyMessage(
|
||||
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage);
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
assertEquals(0, haProxyMessage.refCnt());
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleNonHAProxyMessage() throws InterruptedException {
|
||||
final byte[] bytes = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(bytes);
|
||||
|
||||
final ByteBuf message = Unpooled.wrappedBuffer(bytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message);
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertEquals(message, embeddedChannel.inboundMessages().poll());
|
||||
assertEquals(1, message.refCnt());
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
class HAProxyMessageSender extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final Supplier<HAProxyMessage> messageSupplier;
|
||||
|
||||
HAProxyMessageSender(final Supplier<HAProxyMessage> messageSupplier) {
|
||||
this.messageSupplier = messageSupplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext context) {
|
||||
if (context.channel().isActive()) {
|
||||
context.writeAndFlush(messageSupplier.get());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(final ChannelHandlerContext context) {
|
||||
context.writeAndFlush(messageSupplier.get());
|
||||
context.fireChannelActive();
|
||||
}
|
||||
}
|
|
@ -11,6 +11,8 @@ import java.net.SocketAddress;
|
|||
import java.net.URI;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Supplier;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
@ -34,6 +36,7 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|||
final HttpHeaders headers,
|
||||
final boolean useTls,
|
||||
@Nullable final X509Certificate trustedServerCertificate,
|
||||
@Nullable final Supplier<HAProxyMessage> proxyMessageSupplier,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final WebSocketCloseListener webSocketCloseListener) {
|
||||
|
||||
|
@ -54,7 +57,8 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|||
deviceId,
|
||||
headers,
|
||||
remoteServerAddress,
|
||||
webSocketCloseListener));
|
||||
webSocketCloseListener,
|
||||
proxyMessageSupplier));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -16,6 +16,10 @@ import io.netty.channel.DefaultEventLoopGroup;
|
|||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import java.io.ByteArrayInputStream;
|
||||
|
@ -54,6 +58,8 @@ import org.junit.jupiter.api.AfterEach;
|
|||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
|
@ -234,9 +240,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticated() throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient()) {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = { true, false })
|
||||
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -251,8 +258,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedPlaintext() throws InterruptedException {
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = { true, false })
|
||||
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
|
||||
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||
|
@ -264,6 +272,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
|
||||
nioEventLoopGroup,
|
||||
WebSocketCloseListener.NOOP_LISTENER)
|
||||
.start()) {
|
||||
|
@ -289,7 +298,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final NoiseWebSocketTunnelClient client =
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
@ -369,6 +378,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start()) {
|
||||
|
@ -443,6 +453,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start()) {
|
||||
|
@ -602,12 +613,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
|
||||
throws InterruptedException {
|
||||
|
||||
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
|
||||
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false);
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
|
||||
final ECPublicKey rootPublicKey,
|
||||
final HttpHeaders headers) throws InterruptedException {
|
||||
final HttpHeaders headers,
|
||||
final boolean includeProxyMessage) throws InterruptedException {
|
||||
|
||||
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||
|
@ -619,6 +631,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
headers,
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start();
|
||||
|
@ -642,8 +655,14 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
headers,
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start();
|
||||
}
|
||||
|
||||
private static HAProxyMessage buildProxyMessage() {
|
||||
return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import java.util.HexFormat;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class ProxyProtocolDetectionHandlerTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static final byte[] PROXY_V2_MESSAGE_BYTES =
|
||||
HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb");
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler());
|
||||
}
|
||||
|
||||
@Test
|
||||
void singlePacketProxyMessage() throws InterruptedException {
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES));
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiPacketProxyMessage() throws InterruptedException {
|
||||
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0,
|
||||
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
|
||||
|
||||
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
|
||||
PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
|
||||
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
firstWriteFuture.await();
|
||||
secondWriteFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
|
||||
}
|
||||
|
||||
@Test
|
||||
void singlePacketNonProxyMessage() throws InterruptedException {
|
||||
final byte[] nonProxyProtocolMessage = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage));
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
|
||||
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
|
||||
|
||||
assertInstanceOf(ByteBuf.class, inboundMessage);
|
||||
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiPacketNonProxyMessage() throws InterruptedException {
|
||||
final byte[] nonProxyProtocolMessage = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
|
||||
|
||||
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0,
|
||||
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
|
||||
|
||||
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
|
||||
nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
|
||||
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
firstWriteFuture.await();
|
||||
secondWriteFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
|
||||
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
|
||||
|
||||
assertInstanceOf(ByteBuf.class, inboundMessage);
|
||||
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
|
||||
}
|
||||
}
|
|
@ -13,7 +13,6 @@ import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
|
|||
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
|
Loading…
Reference in New Issue