From 9ec4f0b2f584dff193b3a7723d39796ef2ec8688 Mon Sep 17 00:00:00 2001
From: Jon Chambers <63609320+jon-signal@users.noreply.github.com>
Date: Fri, 24 May 2024 09:11:19 -0400
Subject: [PATCH] Gracefully handle proxy protocol messages at the beginning of
TCP connections
---
service/pom.xml | 5 +
.../grpc/net/HAProxyMessageHandler.java | 32 ++++++
.../grpc/net/NoiseWebSocketTunnelServer.java | 4 +
.../net/ProxyProtocolDetectionHandler.java | 73 ++++++++++++
.../net/EstablishRemoteConnectionHandler.java | 18 ++-
.../grpc/net/HAProxyMessageHandlerTest.java | 62 ++++++++++
.../grpc/net/HAProxyMessageSender.java | 28 +++++
.../grpc/net/NoiseWebSocketTunnelClient.java | 6 +-
...eWebSocketTunnelServerIntegrationTest.java | 35 ++++--
.../ProxyProtocolDetectionHandlerTest.java | 108 ++++++++++++++++++
.../RejectUnsupportedMessagesHandlerTest.java | 1 -
11 files changed, 361 insertions(+), 11 deletions(-)
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java
diff --git a/service/pom.xml b/service/pom.xml
index 0b78c84c0..656c26dc0 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -398,6 +398,11 @@
argparse4j
+
+ io.netty
+ netty-codec-haproxy
+
+
org.glassfish.jersey.test-frameworkjersey-test-framework-core
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java
new file mode 100644
index 000000000..2bd807967
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java
@@ -0,0 +1,32 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.haproxy.HAProxyMessage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has
+ * either handled a proxy protocol message or determined that no such message is coming.
+ */
+public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {
+
+ private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class);
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof HAProxyMessage haProxyMessage) {
+ // Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still
+ // need to clear it from the pipeline to avoid confusing the TLS machinery, though.
+ log.debug("Discarding HAProxy message: {}", haProxyMessage);
+ haProxyMessage.release();
+ } else {
+ super.channelRead(context, message);
+ }
+
+ // Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first
+ // message, all others will just be "normal" messages, and our work here is done.
+ context.pipeline().remove(this);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java
index deabb65be..ef0be249c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java
@@ -89,6 +89,10 @@ public class NoiseWebSocketTunnelServer implements Managed {
.childHandler(new ChannelInitializer() {
@Override
protected void initChannel(SocketChannel socketChannel) {
+ socketChannel.pipeline()
+ .addLast(new ProxyProtocolDetectionHandler())
+ .addLast(new HAProxyMessageHandler());
+
if (sslContext != null) {
socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java
new file mode 100644
index 000000000..3b8951658
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java
@@ -0,0 +1,73 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
+
+/**
+ * A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection.
+ * If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline.
+ * In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol
+ * message, it will remove itself from the pipeline and pass any intercepted down the pipeline.
+ *
+ * @see The PROXY protocol
+ */
+public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter {
+
+ private CompositeByteBuf accumulator;
+
+ @VisibleForTesting
+ static final int PROXY_MESSAGE_DETECTION_BYTES = 12;
+
+ @Override
+ public void handlerAdded(final ChannelHandlerContext context) {
+ // We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get
+ // non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every
+ // practical case, though, we'll be able to tell from the first packet.
+ accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES);
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof ByteBuf byteBuf) {
+ accumulator.addComponent(true, byteBuf);
+
+ switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) {
+ case NEEDS_MORE_DATA -> {
+ }
+
+ case INVALID -> {
+ // We have enough information to determine that this connection is NOT starting with a proxy protocol message,
+ // and we can just pass the accumulated bytes through
+ context.fireChannelRead(accumulator);
+
+ accumulator = null;
+ context.pipeline().remove(this);
+ }
+
+ case DETECTED -> {
+ // We have enough information to know that we're dealing with a proxy protocol message; add appropriate
+ // handlers and pass the accumulated bytes through
+ context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder());
+ context.fireChannelRead(accumulator);
+
+ accumulator = null;
+ context.pipeline().remove(this);
+ }
+ }
+ } else {
+ super.channelRead(context, message);
+ }
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ if (accumulator != null) {
+ accumulator.release();
+ accumulator = null;
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java
index a3f62843b..b6d1e05c7 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java
@@ -8,6 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.haproxy.HAProxyMessage;
+import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
@@ -21,6 +23,7 @@ import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@@ -39,6 +42,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
+ @Nullable private final Supplier proxyMessageSupplier;
private final List