diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 73ed28fc9..b666c5db5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup; import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer; import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup; -import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer; +import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java index f3015b7e8..beeb1469a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java @@ -1,7 +1,9 @@ package org.whispersystems.textsecuregcm.grpc.net; +import org.whispersystems.textsecuregcm.util.NoStackTraceException; + /** * Indicates that an attempt to authenticate a remote client failed for some reason. */ -class ClientAuthenticationException extends Exception { +public class ClientAuthenticationException extends NoStackTraceException { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java index 787b091e2..a62fe0684 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java @@ -1,62 +1,46 @@ package org.whispersystems.textsecuregcm.grpc.net; +import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import javax.crypto.BadPaddingException; +import io.netty.channel.ChannelInboundHandlerAdapter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.ExceptionUtils; /** - * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a - * WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to - * identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the - * connection. + * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions + * thrown in inbound handlers into {@link OutboundCloseErrorMessage}s. */ -class ErrorHandler extends ChannelInboundHandlerAdapter { - - private boolean websocketHandshakeComplete = false; - +public class ErrorHandler extends ChannelInboundHandlerAdapter { private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class); - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { - if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) { - setWebsocketHandshakeComplete(); - } - - context.fireUserEventTriggered(event); - } - - protected void setWebsocketHandshakeComplete() { - this.websocketHandshakeComplete = true; - } + private static OutboundCloseErrorMessage UNAUTHENTICATED_CLOSE = new OutboundCloseErrorMessage( + OutboundCloseErrorMessage.Code.AUTHENTICATION_ERROR, + "Not authenticated"); + private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage( + OutboundCloseErrorMessage.Code.NOISE_ERROR, + "Noise encryption error"); @Override public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { - if (websocketHandshakeComplete) { - final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) { - case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage()); - case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated"); - case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error"); - case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error"); + final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) { + case NoiseHandshakeException e -> new OutboundCloseErrorMessage( + OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR, + e.getMessage()); + case ClientAuthenticationException ignored -> UNAUTHENTICATED_CLOSE; + case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; + case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; default -> { log.warn("An unexpected exception reached the end of the pipeline", cause); - yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR; + yield new OutboundCloseErrorMessage( + OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR, + cause.getMessage()); } }; - context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus)) + context.writeAndFlush(closeMessage) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - } else { - log.debug("Error occurred before websocket handshake complete", cause); - // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful - // way; just close the connection instead. - context.close(); - } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java index c5433358c..59f691dbe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -7,8 +7,6 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.netty.util.ReferenceCountUtil; import java.util.ArrayList; import java.util.List; @@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; * any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC * server. */ -class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { +public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { private final GrpcClientConnectionManager grpcClientConnectionManager; @@ -79,7 +77,9 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { // Close the local connection if the remote channel closes and vice versa remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); localChannelFuture.channel().closeFuture().addListener(closeFuture -> - remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART))); + remoteChannelContext.channel() + .write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed")) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); remoteChannelContext.pipeline() .addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java index be7ddc477..fa35ffcb6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java @@ -7,7 +7,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFutureListener; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.util.AttributeKey; import java.net.InetAddress; import java.util.ArrayList; @@ -63,6 +62,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener static final AttributeKey EPOCH_ATTRIBUTE_KEY = AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch"); + private static OutboundCloseErrorMessage SERVER_CLOSED = + new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"); + private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); /** @@ -161,9 +163,7 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener } private static void closeRemoteChannel(final Channel channel) { - channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED - .toWebSocketCloseStatus("Reauthentication required"))) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } @VisibleForTesting @@ -198,16 +198,16 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener } /** - * Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake + * Handles receipt of a handshake message and associates attributes and headers from the handshake * request with the channel via which the handshake took place. * - * @param channel the channel that completed a WebSocket handshake + * @param channel the channel where the handshake was initiated * @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake * @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null} * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be * {@code null} */ - static void handleHandshakeComplete(final Channel channel, + public static void handleHandshakeInitiated(final Channel channel, final InetAddress preferredRemoteAddress, @Nullable final String userAgentHeader, @Nullable final String acceptLanguageHeader) { @@ -227,11 +227,10 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener } /** - * Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server. + * Handles successful establishment of a Noise connection from a remote client to a local gRPC server. * - * @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC - * server - * @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel + * @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server + * @param remoteChannel the channel from the remote client to the Noise tunnel * @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection */ void handleConnectionEstablished(final LocalChannel localChannel, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java index d65ee230f..cfd0205fc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java @@ -4,7 +4,7 @@ */ package org.whispersystems.textsecuregcm.grpc.net; -enum HandshakePattern { +public enum HandshakePattern { NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"), IK("Noise_IK_25519_ChaChaPoly_BLAKE2b"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java index 063a83a87..411f5094e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java @@ -17,7 +17,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; * Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent} * indicating that initiator connected anonymously. */ -class NoiseAnonymousHandler extends NoiseHandler { +public class NoiseAnonymousHandler extends NoiseHandler { public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) { super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java index b6df55e9d..d28910eee 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java @@ -32,11 +32,11 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; *

* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}. */ -class NoiseAuthenticatedHandler extends NoiseHandler { +public class NoiseAuthenticatedHandler extends NoiseHandler { private final ClientPublicKeysManager clientPublicKeysManager; - NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager, + public NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) { super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair)); this.clientPublicKeysManager = clientPublicKeysManager; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java index e5d22827f..55bc979c7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java @@ -1,10 +1,12 @@ package org.whispersystems.textsecuregcm.grpc.net; +import org.whispersystems.textsecuregcm.util.NoStackTraceException; + /** * Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/ * format or a general encryption error). */ -class NoiseException extends Exception { +public class NoiseException extends NoStackTraceException { public NoiseException(final String message) { super(message); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java index 3c2c9379a..7390b099c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java @@ -26,13 +26,14 @@ import javax.crypto.ShortBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; import org.whispersystems.textsecuregcm.util.ExceptionUtils; /** * A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts * inbound messages, and encrypts outbound messages */ -abstract class NoiseHandler extends ChannelDuplexHandler { +public abstract class NoiseHandler extends ChannelDuplexHandler { private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class); @@ -82,17 +83,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler { @Override public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { try { - if (message instanceof BinaryWebSocketFrame frame) { - if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) { - final String error = "Invalid noise message length " + frame.content().readableBytes(); + if (message instanceof ByteBuf frame) { + if (frame.readableBytes() > Noise.MAX_PACKET_LEN) { + final String error = "Invalid noise message length " + frame.readableBytes(); throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error); } // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. // We'll need to copy it to a heap buffer. - handleInboundMessage(context, ByteBufUtil.getBytes(frame.content())); + handleInboundMessage(context, ByteBufUtil.getBytes(frame)); } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an - // error + // Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error throw new IllegalArgumentException("Unexpected message in pipeline: " + message); } } catch (Exception e) { @@ -122,7 +122,7 @@ abstract class NoiseHandler extends ChannelDuplexHandler { // Now that we've authenticated, write the handshake response byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES); - context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage))) + context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); // The handshake is complete. We can start intercepting read/write for noise encryption/decryption @@ -193,16 +193,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler { // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); - pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)))); + pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer))); } pc.finish(promise); } finally { ReferenceCountUtil.release(byteBuf); } } else { - if (!(message instanceof WebSocketFrame)) { - // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that - // get issued in response to exceptions) + if (!(message instanceof OutboundCloseErrorMessage)) { + // Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames + // that get issued in response to exceptions) log.warn("Unexpected object in pipeline: {}", message); } context.write(message, promise); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java index 1a11bcec6..b7a509728 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java @@ -1,10 +1,12 @@ package org.whispersystems.textsecuregcm.grpc.net; +import org.whispersystems.textsecuregcm.util.NoStackTraceException; + /** * Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or * a general encryption error). */ -class NoiseHandshakeException extends Exception { +public class NoiseHandshakeException extends NoStackTraceException { public NoiseHandshakeException(final String message) { super(message); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java index 57616c0f3..130bda9d8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java @@ -10,4 +10,4 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a * type that performs authentication */ -record NoiseIdentityDeterminedEvent(Optional authenticatedDevice) {} +public record NoiseIdentityDeterminedEvent(Optional authenticatedDevice) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java new file mode 100644 index 000000000..ca2d335bc --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java @@ -0,0 +1,35 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +/** + * An error written to the outbound pipeline that indicates the connection should be closed + */ +public record OutboundCloseErrorMessage(Code code, String message) { + public enum Code { + /** + * The server decided to close the connection. This could be because the server is going away, or it could be + * because the credentials for the connected client have been updated. + */ + SERVER_CLOSED, + + /** + * There was a noise decryption error after the noise session was established + */ + NOISE_ERROR, + + /** + * There was an error establishing the noise handshake + */ + NOISE_HANDSHAKE_ERROR, + + /** + * The provided credentials were not valid + */ + AUTHENTICATION_ERROR, + + INTERNAL_SERVER_ERROR + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java index 2d5effc1a..00afa516d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java @@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; /** * A proxy handler writes all data read from one channel to another peer channel. */ -class ProxyHandler extends ChannelInboundHandlerAdapter { +public class ProxyHandler extends ChannelInboundHandlerAdapter { private final Channel peerChannel; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java new file mode 100644 index 000000000..0ed2f938b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java @@ -0,0 +1,45 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import org.whispersystems.textsecuregcm.grpc.net.NoiseException; + +/** + * In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the + * noise packet to the noise layer as a {@link ByteBuf} for decryption. + *

+ * In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper + * so it can be wire serialized. This handler assumes the first outbound message received will correspond to the + * handshake response, and then the subsequent messages are all data frame payloads. + */ +public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler { + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (msg instanceof NoiseDirectFrame frame) { + if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) { + ReferenceCountUtil.release(msg); + throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType()); + } + ctx.fireChannelRead(frame.content()); + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { + if (msg instanceof ByteBuf bb) { + ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise); + } else { + ctx.write(msg, promise); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java new file mode 100644 index 000000000..96db6280d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java @@ -0,0 +1,71 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.DefaultByteBufHolder; + +public class NoiseDirectFrame extends DefaultByteBufHolder { + + static final byte VERSION = 0x00; + + private final FrameType frameType; + + public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) { + super(data); + this.frameType = frameType; + } + + public FrameType frameType() { + return frameType; + } + + public byte versionedFrameTypeByte() { + final byte frameBits = frameType().getFrameBits(); + return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits); + } + + + public enum FrameType { + /** + * The payload is the initiator message or the responder message for a Noise NK handshake. If established, the + * session will be unauthenticated. + */ + NK_HANDSHAKE((byte) 1), + /** + * The payload is the initiator message or the responder message for a Noise IK handshake. If established, the + * session will be authenticated. + */ + IK_HANDSHAKE((byte) 2), + /** + * The payload is an encrypted noise packet. + */ + DATA((byte) 3), + /** + * A framing layer error occurred. The payload carries error details. + */ + ERROR((byte) 4); + + private final byte frameType; + + FrameType(byte frameType) { + if (frameType != (0x0F & frameType)) { + throw new IllegalStateException("Frame type must fit in 4 bits"); + } + this.frameType = frameType; + } + + public byte getFrameBits() { + return frameType; + } + + public boolean isHandshake() { + return switch (this) { + case IK_HANDSHAKE, NK_HANDSHAKE -> true; + case DATA, ERROR -> false; + }; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java new file mode 100644 index 000000000..184a58323 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import com.southernstorm.noise.protocol.Noise; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; + +/** + * Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes + * have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder} + */ +public class NoiseDirectFrameCodec extends ChannelDuplexHandler { + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (msg instanceof ByteBuf byteBuf) { + try { + ctx.fireChannelRead(deserialize(byteBuf)); + } catch (Exception e) { + ReferenceCountUtil.release(byteBuf); + throw e; + } + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { + if (msg instanceof NoiseDirectFrame noiseDirectFrame) { + try { + // Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the + // network, nothing should have to make another copy of this. If later another layer is added, it may be more + // efficient to reuse the input buffer (typically not direct) by using a composite byte buffer + final ByteBuf serialized = serialize(ctx, noiseDirectFrame); + ctx.writeAndFlush(serialized, promise); + } finally { + ReferenceCountUtil.release(noiseDirectFrame); + } + } else { + ctx.write(msg, promise); + } + } + + private ByteBuf serialize( + final ChannelHandlerContext ctx, + final NoiseDirectFrame noiseDirectFrame) { + if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) { + throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes()); + } + + // 1 version/frametype byte, 2 length bytes, content + final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes()); + + byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte()); + byteBuf.writeShort(noiseDirectFrame.content().readableBytes()); + byteBuf.writeBytes(noiseDirectFrame.content()); + return byteBuf; + } + + private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception { + final byte versionAndFrameByte = byteBuf.readByte(); + final int version = (versionAndFrameByte & 0xF0) >> 4; + if (version != NoiseDirectFrame.VERSION) { + throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version); + } + final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F); + final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) { + case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE; + case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE; + case 3 -> NoiseDirectFrame.FrameType.DATA; + case 4 -> NoiseDirectFrame.FrameType.ERROR; + default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits); + }; + + final int length = Short.toUnsignedInt(byteBuf.readShort()); + if (length != byteBuf.readableBytes()) { + throw new IllegalArgumentException( + "Payload length did not match remaining buffer, should have been guaranteed by a previous handler"); + } + return new NoiseDirectFrame(frameType, byteBuf.readSlice(length)); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java new file mode 100644 index 000000000..a2204b660 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java @@ -0,0 +1,75 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +import java.io.IOException; +import java.net.InetSocketAddress; + +/** + * Waits for a Handshake {@link NoiseDirectFrame} and then installs a {@link NoiseDirectDataFrameCodec} and + * {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} and removes itself + */ +public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter { + + private final ClientPublicKeysManager clientPublicKeysManager; + private final ECKeyPair ecKeyPair; + + public NoiseDirectHandshakeSelector(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) { + this.clientPublicKeysManager = clientPublicKeysManager; + this.ecKeyPair = ecKeyPair; + } + + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (msg instanceof NoiseDirectFrame frame) { + try { + // We've received an inbound handshake frame so we know what kind of NoiseHandler we need (authenticated or + // anonymous). We construct it here, and then remember the handshake type so we can annotate our handshake + // response with the correct frame type whenever we receive it. + final ChannelDuplexHandler noiseHandler = switch (frame.frameType()) { + case DATA, ERROR -> + throw new NoiseHandshakeException("Invalid frame type for first message " + frame.frameType()); + case IK_HANDSHAKE -> new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair); + case NK_HANDSHAKE -> new NoiseAnonymousHandler(ecKeyPair); + }; + if (ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) { + // TODO: Provide connection metadata / headers in handshake payload + GrpcClientConnectionManager.handleHandshakeInitiated(ctx.channel(), + inetSocketAddress.getAddress(), + "NoiseDirect", + ""); + + } else { + throw new IOException("Could not determine remote address"); + } + + // Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames + // should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded + // for network serialization. Note that we need to install the Data frame handler before firing the read, + // because we may receive an outbound message from the noiseHandler + ctx.pipeline().addAfter(ctx.name(), null, noiseHandler); + ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec()); + ctx.fireChannelRead(frame.content()); + } catch (Exception e) { + ReferenceCountUtil.release(msg); + throw e; + } + } else { + ctx.fireChannelRead(msg); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java new file mode 100644 index 000000000..500400cd9 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java @@ -0,0 +1,39 @@ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; + +/** + * Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are + * written, the channel is closed + */ +class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof OutboundCloseErrorMessage err) { + final NoiseDirectProtos.Error.Type type = switch (err.code()) { + case SERVER_CLOSED -> NoiseDirectProtos.Error.Type.UNAVAILABLE; + case NOISE_ERROR -> NoiseDirectProtos.Error.Type.ENCRYPTION_ERROR; + case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.Error.Type.HANDSHAKE_ERROR; + case AUTHENTICATION_ERROR -> NoiseDirectProtos.Error.Type.AUTHENTICATION_ERROR; + case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.Error.Type.INTERNAL_ERROR; + }; + final NoiseDirectProtos.Error proto = NoiseDirectProtos.Error.newBuilder() + .setType(type) + .setMessage(err.message()) + .build(); + final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize()); + proto.writeTo(new ByteBufOutputStream(byteBuf)); + ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.ERROR, byteBuf)) + .addListener(ChannelFutureListener.CLOSE); + } else { + ctx.write(msg, promise); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java new file mode 100644 index 000000000..3bca2e69e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java @@ -0,0 +1,90 @@ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import com.google.common.annotations.VisibleForTesting; +import com.southernstorm.noise.protocol.Noise; +import io.dropwizard.lifecycle.Managed; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import java.net.InetSocketAddress; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler; +import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler; +import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +/** + * A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom + * binary framing protocol) and passes it through to a local gRPC server. + */ +public class NoiseDirectTunnelServer implements Managed { + + private final ServerBootstrap bootstrap; + private ServerSocketChannel channel; + + private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class); + + public NoiseDirectTunnelServer(final int port, + final NioEventLoopGroup eventLoopGroup, + final GrpcClientConnectionManager grpcClientConnectionManager, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress) { + + this.bootstrap = new ServerBootstrap() + .group(eventLoopGroup) + .channel(NioServerSocketChannel.class) + .localAddress(port) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel socketChannel) { + socketChannel.pipeline() + .addLast(new ProxyProtocolDetectionHandler()) + .addLast(new HAProxyMessageHandler()); + + socketChannel.pipeline() + // frame byte followed by a 2-byte length field + .addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2)) + // Parses NoiseDirectFrames from wire bytes and vice versa + .addLast(new NoiseDirectFrameCodec()) + // Turn generic OutboundCloseErrorMessages into noise direct error frames + .addLast(new NoiseDirectOutboundErrorHandler()) + // Waits for the handshake to finish and then replaces itself with a NoiseDirectFrameCodec and a + // NoiseHandler to handle noise encryption/decryption + .addLast(new NoiseDirectHandshakeSelector(clientPublicKeysManager, ecKeyPair)) + // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler + // once the Noise handshake has completed + .addLast(new EstablishLocalGrpcConnectionHandler( + grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) + .addLast(new ErrorHandler()); + } + }); + } + + @VisibleForTesting + public InetSocketAddress getLocalAddress() { + return channel.localAddress(); + } + + @Override + public void start() throws InterruptedException { + channel = (ServerSocketChannel) bootstrap.bind().await().channel(); + } + + @Override + public void stop() throws InterruptedException { + if (channel != null) { + channel.close().await(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java similarity index 59% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java rename to service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java index a0ba4ee0b..c4dc8528a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java @@ -1,12 +1,11 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; enum ApplicationWebSocketCloseReason { NOISE_HANDSHAKE_ERROR(4001), CLIENT_AUTHENTICATION_ERROR(4002), - NOISE_ENCRYPTION_ERROR(4003), - REAUTHENTICATION_REQUIRED(4004); + NOISE_ENCRYPTION_ERROR(4003); private final int statusCode; @@ -17,8 +16,4 @@ enum ApplicationWebSocketCloseReason { public int getStatusCode() { return statusCode; } - - WebSocketCloseStatus toWebSocketCloseStatus(final String reason) { - return new WebSocketCloseStatus(statusCode, reason); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java similarity index 94% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java rename to service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java index fcf06beff..cd37ac603 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import com.google.common.annotations.VisibleForTesting; import com.southernstorm.noise.protocol.Noise; @@ -28,6 +28,7 @@ import javax.net.ssl.SSLException; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.*; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; /** @@ -103,7 +104,10 @@ public class NoiseWebSocketTunnelServer implements Managed { // request and passed it down the pipeline .addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH)) .addLast(new WebSocketServerProtocolHandler("/", true)) + // Turn generic OutboundCloseErrorMessages into websocket close frames + .addLast(new WebSocketOutboundErrorHandler()) .addLast(new RejectUnsupportedMessagesHandler()) + .addLast(new WebsocketPayloadCodec()) // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once // a WebSocket handshake has been completed .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java similarity index 95% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java rename to service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java index 41d7436e2..d313951b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java similarity index 98% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java rename to service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java index ddca6c4fe..e70cdec8e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java new file mode 100644 index 000000000..4dfb54be4 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java @@ -0,0 +1,64 @@ +package org.whispersystems.textsecuregcm.grpc.net.websocket; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import javax.crypto.BadPaddingException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.ClientAuthenticationException; +import org.whispersystems.textsecuregcm.grpc.net.NoiseException; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; +import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; + +/** + * Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames + */ +class WebSocketOutboundErrorHandler extends ChannelDuplexHandler { + + private boolean websocketHandshakeComplete = false; + + private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class); + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { + if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + setWebsocketHandshakeComplete(); + } + + context.fireUserEventTriggered(event); + } + + protected void setWebsocketHandshakeComplete() { + this.websocketHandshakeComplete = true; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof OutboundCloseErrorMessage err) { + if (websocketHandshakeComplete) { + final int status = switch (err.code()) { + case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code(); + case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode(); + case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode(); + case AUTHENTICATION_ERROR -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode(); + case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(); + }; + ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } else { + log.debug("Error {} occurred before websocket handshake complete", err); + // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful + // way; just close the connection instead. + ctx.close(); + } + } else { + ctx.write(msg, promise); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java similarity index 94% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java rename to service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java index f847314b0..b41e4328d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.InetAddresses; @@ -20,6 +20,9 @@ import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; /** @@ -74,7 +77,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { preferredRemoteAddress = maybePreferredRemoteAddress.get(); } - GrpcClientConnectionManager.handleHandshakeComplete(context.channel(), + GrpcClientConnectionManager.handleHandshakeInitiated(context.channel(), preferredRemoteAddress, handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java new file mode 100644 index 000000000..4ef5bd151 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc.net.websocket; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; + +/** + * Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a + * {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise + * packet buffers in BinaryWebsocketFrames for writing through the websocket layer. + */ +public class WebsocketPayloadCodec extends ChannelDuplexHandler { + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + if (msg instanceof BinaryWebSocketFrame frame) { + ctx.fireChannelRead(frame.content()); + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { + if (msg instanceof ByteBuf bb) { + ctx.write(new BinaryWebSocketFrame(bb), promise); + } else { + ctx.write(msg, promise); + } + } +} diff --git a/service/src/main/proto/NoiseDirect.proto b/service/src/main/proto/NoiseDirect.proto new file mode 100644 index 000000000..10a7801df --- /dev/null +++ b/service/src/main/proto/NoiseDirect.proto @@ -0,0 +1,22 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +syntax = "proto3"; + +option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect"; +option java_outer_classname = "NoiseDirectProtos"; + +message Error { + enum Type { + UNSPECIFIED = 0; + HANDSHAKE_ERROR = 1; + ENCRYPTION_ERROR = 2; + UNAVAILABLE = 3; + INTERNAL_ERROR = 4; + AUTHENTICATION_ERROR = 5; + } + Type type = 1; + string message = 2; +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java index 79fb2078f..9cee9e7f1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java @@ -4,7 +4,7 @@ import io.netty.util.ResourceLeakDetector; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -abstract class AbstractLeakDetectionTest { +public abstract class AbstractLeakDetectionTest { private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java index 7c8a40f41..0b699b543 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java @@ -119,15 +119,15 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { * waiting messages in the channel, return null. */ byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException { - final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); if (responseFrame == null) { return null; } - final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16]; + final byte[] plaintext = new byte[responseFrame.readableBytes() - 16]; final int read = clientCipherPair.getReceiver().decryptWithAd(null, - ByteBufUtil.getBytes(responseFrame.content()), 0, + ByteBufUtil.getBytes(responseFrame), 0, plaintext, 0, - responseFrame.content().readableBytes()); + responseFrame.readableBytes()); assertEquals(read, plaintext.length); return plaintext; } @@ -140,7 +140,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await(); + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await(); assertFalse(writeFuture.isSuccess()); assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); @@ -150,18 +150,18 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { @Test void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException { - final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7]; + final ByteBuf[] frames = new ByteBuf[7]; for (int i = 0; i < frames.length; i++) { final byte[] contentBytes = new byte[17]; ThreadLocalRandom.current().nextBytes(contentBytes); - frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); + frames[i] = Unpooled.wrappedBuffer(contentBytes); embeddedChannel.writeOneInbound(frames[i]).await(); } - for (final BinaryWebSocketFrame frame : frames) { + for (final ByteBuf frame : frames) { assertEquals(0, frame.refCnt()); } @@ -169,11 +169,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { } @Test - void handleNonWebSocketBinaryFrame() throws Throwable { + void handleNonByteBufBinaryFrame() throws Throwable { final byte[] contentBytes = new byte[17]; ThreadLocalRandom.current().nextBytes(contentBytes); - final ByteBuf message = Unpooled.wrappedBuffer(contentBytes); + final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await(); @@ -192,7 +192,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()]; clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length); - final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext)); + final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext); assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess()); assertEquals(0, ciphertextFrame.refCnt()); @@ -206,7 +206,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { final byte[] bogusCiphertext = new byte[32]; io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext); - final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext)); + final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext); final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await(); assertEquals(0, ciphertextFrame.refCnt()); @@ -235,11 +235,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { assertTrue(writePlaintextFuture.await().isSuccess()); assertEquals(0, plaintextBuffer.refCnt()); - final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); assertNotNull(ciphertextFrame); assertTrue(embeddedChannel.outboundMessages().isEmpty()); - final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); + final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame); ciphertextFrame.release(); final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()]; @@ -272,10 +272,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { final byte[] decryptedPlaintext = new byte[plaintextLength]; int plaintextOffset = 0; - BinaryWebSocketFrame ciphertextFrame; - while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) { - assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN); - final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); + ByteBuf ciphertextFrame; + while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) { + assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN); + final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame); ciphertextFrame.release(); plaintextOffset += clientCipherStatePair.getReceiver() .decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length); @@ -289,7 +289,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { public void writeHugeInboundMessage() throws Throwable { doHandshake(); final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1); - embeddedChannel.pipeline().fireChannelRead(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(big))); + embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big)); assertThrows(NoiseException.class, embeddedChannel::checkException); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java new file mode 100644 index 000000000..f2d75e196 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java @@ -0,0 +1,426 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.ServerBuilder; +import io.grpc.Status; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.StreamObserver; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.haproxy.HAProxyCommand; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; +import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.signal.chat.rpc.EchoRequest; +import org.signal.chat.rpc.EchoResponse; +import org.signal.chat.rpc.EchoServiceGrpc; +import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; +import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.RequestAttributesGrpc; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor; +import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; +import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; +import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest { + + private static NioEventLoopGroup nioEventLoopGroup; + private static DefaultEventLoopGroup defaultEventLoopGroup; + private static ExecutorService delegatedTaskExecutor; + private static ExecutorService serverCallExecutor; + + private GrpcClientConnectionManager grpcClientConnectionManager; + private ClientPublicKeysManager clientPublicKeysManager; + + private ECKeyPair serverKeyPair; + private ECKeyPair clientKeyPair; + + private ManagedLocalGrpcServer authenticatedGrpcServer; + private ManagedLocalGrpcServer anonymousGrpcServer; + + private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + + public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16); + + @BeforeAll + static void setUpBeforeAll() { + nioEventLoopGroup = new NioEventLoopGroup(); + defaultEventLoopGroup = new DefaultEventLoopGroup(); + delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor(); + serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor(); + } + + @BeforeEach + void setUp() throws Exception { + + clientKeyPair = Curve.generateKeyPair(); + serverKeyPair = Curve.generateKeyPair(); + + grpcClientConnectionManager = new GrpcClientConnectionManager(); + + clientPublicKeysManager = mock(ClientPublicKeysManager.class); + when(clientPublicKeysManager.findPublicKey(any(), anyByte())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated"); + final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous"); + + authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + serverBuilder + .executor(serverCallExecutor) + .addService(new RequestAttributesServiceImpl()) + .addService(new EchoServiceImpl()) + .intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager)) + .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) + .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); + } + }; + + authenticatedGrpcServer.start(); + + anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + serverBuilder + .executor(serverCallExecutor) + .addService(new RequestAttributesServiceImpl()) + .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) + .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); + } + }; + + anonymousGrpcServer.start(); + this.start( + nioEventLoopGroup, + delegatedTaskExecutor, + grpcClientConnectionManager, + clientPublicKeysManager, + serverKeyPair, + authenticatedGrpcServerAddress, anonymousGrpcServerAddress, + RECOGNIZED_PROXY_SECRET); + } + + + protected abstract void start( + final NioEventLoopGroup eventLoopGroup, + final Executor delegatedTaskExecutor, + final GrpcClientConnectionManager grpcClientConnectionManager, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair serverKeyPair, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress, + final String recognizedProxySecret) throws Exception; + protected abstract void stop() throws Exception; + protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey); + + public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason) + throws ExecutionException, InterruptedException, TimeoutException { + final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS); + assertEquals(reason, result.closeReason()); + } + + @AfterEach + void tearDown() throws Exception { + authenticatedGrpcServer.stop(); + anonymousGrpcServer.stop(); + this.stop(); + } + + @AfterAll + static void tearDownAfterAll() throws InterruptedException { + nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); + defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); + + delegatedTaskExecutor.shutdown(); + //noinspection ResultOfMethodCallIgnored + delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); + + serverCallExecutor.shutdown(); + //noinspection ResultOfMethodCallIgnored + serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException { + try (final NoiseTunnelClient client = authenticated() + .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) + .build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException { + + // Try to verify the server's public key with something other than the key with which it was signed + try (final NoiseTunnelClient client = authenticated() + .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) + .build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); + } + } + + @Test + void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException { + + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); + + try (final NoiseTunnelClient client = authenticated().build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR); + } + } + + @Test + void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException { + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + try (final NoiseTunnelClient client = authenticated().build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + + assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR); + } + + } + + @Test + void connectAnonymous() throws InterruptedException { + try (final NoiseTunnelClient client = anonymous().build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertTrue(response.getAccountIdentifier().isEmpty()); + assertEquals(0, response.getDeviceId()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException { + + // Try to verify the server's public key with something other than the key with which it was signed + try (final NoiseTunnelClient client = anonymous() + .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) + .build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); + } + + } + + protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) { + return NettyChannelBuilder.forAddress(localAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(defaultEventLoopGroup) + .usePlaintext() + .build(); + } + + + @Test + void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException { + + try (final NoiseTunnelClient client = authenticated().build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); + + grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); + final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS); + assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason()); + assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException { + try (final NoiseTunnelClient client = authenticated().build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); + + final CountDownLatch responseCountDownLatch = new CountDownLatch(1); + + // Start an open-ended server call and leave it in a non-complete state + final StreamObserver echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream( + new StreamObserver<>() { + @Override + public void onNext(final EchoResponse echoResponse) { + responseCountDownLatch.countDown(); + } + + @Override + public void onError(final Throwable throwable) { + } + + @Override + public void onCompleted() { + } + }); + + // Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before + // the request even starts. Make sure we've done at least one request/response pair to ensure that the call has + // truly started before requesting connection closure. + echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()); + assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS)); + + grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); + try { + client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS); + fail("Channel should not close until active requests have finished"); + } catch (TimeoutException e) { + } + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel) + .echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build())); + + // Complete the open-ended server call + echoRequestStreamObserver.onCompleted(); + + final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS); + assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator()); + assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason()); + } finally { + channel.shutdown(); + } + } + } + + protected NoiseTunnelClient.Builder anonymous() { + return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey()); + } + + protected NoiseTunnelClient.Builder authenticated() { + return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey()) + .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID); + } + + private static Supplier proxyMessageSupplier(boolean includeProxyMesage) { + return includeProxyMesage + ? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, + "10.0.0.1", "10.0.0.2", 12345, 443) + : null; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java deleted file mode 100644 index ae7bcc133..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; - -class ClientErrorHandler extends ErrorHandler { - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { - if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { - if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { - setWebsocketHandshakeComplete(); - } - } - - super.userEventTriggered(context, event); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java deleted file mode 100644 index 66039f7b9..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java +++ /dev/null @@ -1,198 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.Noise; -import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; -import io.netty.handler.codec.http.HttpClientCodec; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import io.netty.handler.codec.http.websocketx.WebSocketVersion; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.util.ReferenceCountUtil; -import java.net.SocketAddress; -import java.net.URI; -import java.nio.ByteBuffer; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; -import java.util.function.Supplier; -import javax.annotation.Nullable; -import javax.net.ssl.SSLException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -/** - * Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote - * gRPC server - */ -class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { - - private final boolean useTls; - @Nullable private final X509Certificate trustedServerCertificate; - private final URI websocketUri; - private final boolean authenticated; - @Nullable private final ECKeyPair ecKeyPair; - private final ECPublicKey serverPublicKey; - @Nullable private final UUID accountIdentifier; - private final byte deviceId; - private final HttpHeaders headers; - private final SocketAddress remoteServerAddress; - private final WebSocketCloseListener webSocketCloseListener; - @Nullable private final Supplier proxyMessageSupplier; - // If provided, will be sent with the payload in the noise handshake - private final byte[] fastOpenRequest; - - private final List pendingReads = new ArrayList<>(); - - private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake"; - - EstablishRemoteConnectionHandler( - final boolean useTls, - @Nullable final X509Certificate trustedServerCertificate, - final URI websocketUri, - final boolean authenticated, - @Nullable final ECKeyPair ecKeyPair, - final ECPublicKey serverPublicKey, - @Nullable final UUID accountIdentifier, - final byte deviceId, - final HttpHeaders headers, - final SocketAddress remoteServerAddress, - final WebSocketCloseListener webSocketCloseListener, - @Nullable Supplier proxyMessageSupplier, - @Nullable byte[] fastOpenRequest) { - - this.useTls = useTls; - this.trustedServerCertificate = trustedServerCertificate; - this.websocketUri = websocketUri; - this.authenticated = authenticated; - this.ecKeyPair = ecKeyPair; - this.serverPublicKey = serverPublicKey; - this.accountIdentifier = accountIdentifier; - this.deviceId = deviceId; - this.headers = headers; - this.remoteServerAddress = remoteServerAddress; - this.webSocketCloseListener = webSocketCloseListener; - this.proxyMessageSupplier = proxyMessageSupplier; - this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest; - } - - @Override - public void handlerAdded(final ChannelHandlerContext localContext) { - new Bootstrap() - .channel(NioSocketChannel.class) - .group(localContext.channel().eventLoop()) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(final SocketChannel channel) throws SSLException { - - if (proxyMessageSupplier != null) { - // In a production setting, we'd want some mechanism to remove these handlers after the initial message - // were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a - // pair of inert handlers in the pipeline. - channel.pipeline() - .addLast(HAProxyMessageEncoder.INSTANCE) - .addLast(new HAProxyMessageSender(proxyMessageSupplier)); - } - - if (useTls) { - final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); - - if (trustedServerCertificate != null) { - sslContextBuilder.trustManager(trustedServerCertificate); - } - - channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc())); - } - - final NoiseClientHandshakeHelper helper = authenticated - ? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair) - : NoiseClientHandshakeHelper.NK(serverPublicKey); - - channel.pipeline() - .addLast(new HttpClientCodec()) - .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) - // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we - // want to react to them on our own, we need to catch them before they hit that handler. - .addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener)) - .addLast(new WebSocketClientProtocolHandler(websocketUri, - WebSocketVersion.V13, - null, - false, - headers, - Noise.MAX_PACKET_LEN, - 10_000)) - .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener)) - // Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done - .addLast(new NoiseClientHandshakeHandler(helper, initialPayload())) - .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event) - throws Exception { - if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) { - remoteContext.pipeline() - .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); - localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); - - // If there was a payload response on the handshake, write it back to our gRPC client - handshakeCompleteEvent.fastResponse().ifPresent(plaintext -> - localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext))); - - // Forward any messages we got from our gRPC client, now will be proxied to the remote context - pendingReads.forEach(localContext::fireChannelRead); - pendingReads.clear(); - localContext.pipeline().remove(EstablishRemoteConnectionHandler.this); - } - - super.userEventTriggered(remoteContext, event); - } - }) - .addLast(new ClientErrorHandler()); - } - }) - .connect(remoteServerAddress) - .addListener((ChannelFutureListener) future -> { - if (future.isSuccess()) { - // Close the local connection if the remote channel closes and vice versa - future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close()); - localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); - } else { - localContext.close(); - } - }); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - pendingReads.add(message); - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - pendingReads.forEach(ReferenceCountUtil::release); - pendingReads.clear(); - } - - private byte[] initialPayload() { - if (!authenticated) { - return fastOpenRequest; - } - - final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length); - bb.putLong(accountIdentifier.getMostSignificantBits()); - bb.putLong(accountIdentifier.getLeastSignificantBits()); - bb.put(deviceId); - bb.put(fastOpenRequest); - bb.flip(); - return bb.array(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java deleted file mode 100644 index 7f7b68df6..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java +++ /dev/null @@ -1,9 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBuf; - -record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java index 1cba6d4b1..654ac5d72 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java @@ -146,14 +146,14 @@ class GrpcClientConnectionManagerTest { @ParameterizedTest @MethodSource - void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress, + void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress, final String userAgentHeader, final String acceptLanguageHeader, final RequestAttributes expectedRequestAttributes) { final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel, + GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel, preferredRemoteAddress, userAgentHeader, acceptLanguageHeader); @@ -162,7 +162,7 @@ class GrpcClientConnectionManagerTest { embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get()); } - private static List handleHandshakeCompleteRequestAttributes() { + private static List handleHandshakeInitiatedRequestAttributes() { final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); return List.of( diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java deleted file mode 100644 index e946c937a..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; - -class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter { - - private final WebSocketCloseListener webSocketCloseListener; - - public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) { - this.webSocketCloseListener = webSocketCloseListener; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { - webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode()); - } - - super.channelRead(context, message); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java index 51518a274..987e7f166 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java @@ -10,9 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import java.util.Optional; import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; @@ -49,22 +49,18 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { assertEquals( initiateHandshakeMessageLength, clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length)); - - final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame( - Unpooled.wrappedBuffer(initiateHandshakeMessage)); - - assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess()); - assertEquals(0, initiateHandshakeFrame.refCnt()); + final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage); + assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess()); + assertEquals(0, initiateHandshakeMessageBuf.refCnt()); embeddedChannel.runPendingTasks(); // Read responder handshake message assertFalse(embeddedChannel.outboundMessages().isEmpty()); - final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame) - embeddedChannel.outboundMessages().poll(); + final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); @SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes = - new byte[responderHandshakeFrame.content().readableBytes()]; - responderHandshakeFrame.content().readBytes(responderHandshakeBytes); + new byte[responderHandshakeFrame.readableBytes()]; + responderHandshakeFrame.readBytes(responderHandshakeBytes); // ephemeral key, empty encrypted payload AEAD tag final byte[] handshakeResponsePayload = new byte[32 + 16]; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java index b8c539295..b498ef954 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java @@ -15,10 +15,10 @@ import static org.mockito.Mockito.when; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.HandshakeState; import com.southernstorm.noise.protocol.Noise; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.util.internal.EmptyArrays; import java.nio.ByteBuffer; import java.security.NoSuchAlgorithmException; @@ -34,6 +34,7 @@ import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @@ -204,13 +205,12 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); - final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer( - initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)))); + final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer( + initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))); assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess()); // While waiting for the public key, send another message - final ChannelFuture f = embeddedChannel.writeOneInbound( - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await(); + final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await(); assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow()); findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); @@ -267,8 +267,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { final HandshakeState clientHandshakeState = clientHandshakeState(); final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload); - final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame( - Unpooled.wrappedBuffer(initiatorMessage)); + final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage); final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await(); assertEquals(0, initiatorMessageFrame.refCnt()); if (!await.isSuccess()) { @@ -286,11 +285,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { assertFalse(embeddedChannel.outboundMessages().isEmpty()); - final BinaryWebSocketFrame serverStaticKeyMessageFrame = - (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = - new byte[serverStaticKeyMessageFrame.content().readableBytes()]; - serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); + new byte[serverStaticKeyMessageFrame.readableBytes()]; + serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes); assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java deleted file mode 100644 index c8df5b2be..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java +++ /dev/null @@ -1,55 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import java.util.Optional; - -class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter { - - private final NoiseClientHandshakeHelper handshakeHelper; - private final byte[] payload; - - NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) { - this.handshakeHelper = handshakeHelper; - this.payload = payload; - } - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { - if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { - if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { - byte[] handshakeMessage = handshakeHelper.write(payload); - context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage))) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - } - } - super.userEventTriggered(context, event); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) - throws NoiseHandshakeException { - if (message instanceof BinaryWebSocketFrame frame) { - try { - final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content())); - final Optional fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload); - context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split())); - context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse)); - } finally { - frame.release(); - } - } else { - context.fireChannelRead(message); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - handshakeHelper.destroy(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java index e818a93e3..f6c9a4f79 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper; public class NoiseHandshakeHelperTest { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java deleted file mode 100644 index 74e760e6c..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java +++ /dev/null @@ -1,160 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.ByteBufUtil; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpHeaders; -import java.net.SocketAddress; -import java.net.URI; -import java.security.cert.X509Certificate; -import java.util.UUID; -import java.util.function.Function; -import java.util.function.Supplier; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -class NoiseWebSocketTunnelClient implements AutoCloseable { - - private final ServerBootstrap serverBootstrap; - private Channel serverChannel; - - static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated"); - static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous"); - - static class Builder { - - final SocketAddress remoteServerAddress; - NioEventLoopGroup eventLoopGroup; - ECPublicKey serverPublicKey; - - URI websocketUri = ANONYMOUS_WEBSOCKET_URI; - HttpHeaders headers = new DefaultHttpHeaders(); - WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER; - - boolean authenticated = false; - ECKeyPair ecKeyPair = null; - UUID accountIdentifier = null; - byte deviceId = 0x00; - boolean useTls; - X509Certificate trustedServerCertificate = null; - Supplier proxyMessageSupplier = null; - - Builder( - final SocketAddress remoteServerAddress, - final NioEventLoopGroup eventLoopGroup, - final ECPublicKey serverPublicKey) { - this.remoteServerAddress = remoteServerAddress; - this.eventLoopGroup = eventLoopGroup; - this.serverPublicKey = serverPublicKey; - } - - Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) { - this.authenticated = true; - this.accountIdentifier = accountIdentifier; - this.deviceId = deviceId; - this.ecKeyPair = ecKeyPair; - this.websocketUri = AUTHENTICATED_WEBSOCKET_URI; - return this; - } - - Builder setWebsocketUri(final URI websocketUri) { - this.websocketUri = websocketUri; - return this; - } - - Builder setUseTls(X509Certificate trustedServerCertificate) { - this.useTls = true; - this.trustedServerCertificate = trustedServerCertificate; - return this; - } - - Builder setProxyMessageSupplier(Supplier proxyMessageSupplier) { - this.proxyMessageSupplier = proxyMessageSupplier; - return this; - } - - Builder setHeaders(final HttpHeaders headers) { - this.headers = headers; - return this; - } - - Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) { - this.webSocketCloseListener = webSocketCloseListener; - return this; - } - - Builder setServerPublicKey(ECPublicKey serverPublicKey) { - this.serverPublicKey = serverPublicKey; - return this; - } - - NoiseWebSocketTunnelClient build() { - final NoiseWebSocketTunnelClient client = - new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler( - useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey, - accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier, - fastOpenRequest)); - client.start(); - return client; - } - } - - private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup, - Function handler) { - - this.serverBootstrap = new ServerBootstrap() - .localAddress(new LocalAddress("websocket-noise-tunnel-client")) - .channel(LocalServerChannel.class) - .group(eventLoopGroup) - .childHandler(new ChannelInitializer() { - @Override - protected void initChannel(final LocalChannel localChannel) { - localChannel.pipeline() - // We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the - // stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put - // in the handshake. - .addLast(Http2Buffering.handler()) - // Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At - // that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually - // connect to the remote service - .addLast(new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { - if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) { - byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest()); - requestBufferedEvent.fastOpenRequest().release(); - ctx.pipeline().addLast(handler.apply(fastOpenRequest)); - } - super.userEventTriggered(ctx, evt); - } - }) - .addLast(new ClientErrorHandler()); - } - }); - } - - - LocalAddress getLocalAddress() { - return (LocalAddress) serverChannel.localAddress(); - } - - private NoiseWebSocketTunnelClient start() { - serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel(); - return this; - } - - @Override - public void close() throws InterruptedException { - serverChannel.close().await(); - } - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java deleted file mode 100644 index 301443a4a..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java +++ /dev/null @@ -1,705 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; -import io.grpc.ServerBuilder; -import io.grpc.Status; -import io.grpc.netty.NettyChannelBuilder; -import io.grpc.stub.StreamObserver; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.handler.codec.haproxy.HAProxyCommand; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; -import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpHeaders; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.security.KeyFactory; -import java.security.KeyStore; -import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.SecureRandom; -import java.security.cert.CertificateException; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; -import java.security.spec.InvalidKeySpecException; -import java.security.spec.PKCS8EncodedKeySpec; -import java.util.Base64; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.signal.chat.rpc.EchoRequest; -import org.signal.chat.rpc.EchoResponse; -import org.signal.chat.rpc.EchoServiceGrpc; -import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; -import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; -import org.signal.chat.rpc.GetRequestAttributesRequest; -import org.signal.chat.rpc.GetRequestAttributesResponse; -import org.signal.chat.rpc.RequestAttributesGrpc; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; -import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; -import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor; -import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; -import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; -import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; -import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTest { - - private static NioEventLoopGroup nioEventLoopGroup; - private static DefaultEventLoopGroup defaultEventLoopGroup; - private static ExecutorService delegatedTaskExecutor; - private static ExecutorService serverCallExecutor; - - private static X509Certificate serverTlsCertificate; - - private GrpcClientConnectionManager grpcClientConnectionManager; - private ClientPublicKeysManager clientPublicKeysManager; - - private ECKeyPair serverKeyPair; - private ECKeyPair clientKeyPair; - - private ManagedLocalGrpcServer authenticatedGrpcServer; - private ManagedLocalGrpcServer anonymousGrpcServer; - - private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer; - private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer; - - private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); - private static final byte DEVICE_ID = Device.PRIMARY_ID; - - private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16); - - // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. - // They were generated with: - // - // ```shell - // openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost" - // ``` - private static final String SERVER_CERTIFICATE = """ - -----BEGIN CERTIFICATE----- - MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw - FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx - MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA - IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV - jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq - SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME - GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG - SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw - XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi - iOr9sHiO8Rn2u0xRKgU5Ig== - -----END CERTIFICATE----- - """; - - // BEGIN/END PRIVATE KEY header/footer removed for easier parsing - private static final String SERVER_PRIVATE_KEY = """ - MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj - kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd - PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA - O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo= - """; - - @BeforeAll - static void setUpBeforeAll() throws CertificateException { - nioEventLoopGroup = new NioEventLoopGroup(); - defaultEventLoopGroup = new DefaultEventLoopGroup(); - delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor(); - serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor(); - - final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); - serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( - new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8))); - } - - @BeforeEach - void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException { - - final PrivateKey serverTlsPrivateKey; - { - final KeyFactory keyFactory = KeyFactory.getInstance("EC"); - serverTlsPrivateKey = - keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY))); - } - - clientKeyPair = Curve.generateKeyPair(); - serverKeyPair = Curve.generateKeyPair(); - - grpcClientConnectionManager = new GrpcClientConnectionManager(); - - clientPublicKeysManager = mock(ClientPublicKeysManager.class); - when(clientPublicKeysManager.findPublicKey(any(), anyByte())) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated"); - final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous"); - - authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { - @Override - protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder - .executor(serverCallExecutor) - .addService(new RequestAttributesServiceImpl()) - .addService(new EchoServiceImpl()) - .intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager)) - .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) - .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); - } - }; - - authenticatedGrpcServer.start(); - - anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { - @Override - protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder - .executor(serverCallExecutor) - .addService(new RequestAttributesServiceImpl()) - .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) - .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); - } - }; - - anonymousGrpcServer.start(); - - tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, - new X509Certificate[]{serverTlsCertificate}, - serverTlsPrivateKey, - nioEventLoopGroup, - delegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress, - RECOGNIZED_PROXY_SECRET); - - tlsNoiseWebSocketTunnelServer.start(); - - plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, - null, - null, - nioEventLoopGroup, - delegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress, - RECOGNIZED_PROXY_SECRET); - - plaintextNoiseWebSocketTunnelServer.start(); - } - - @AfterEach - void tearDown() throws InterruptedException { - tlsNoiseWebSocketTunnelServer.stop(); - plaintextNoiseWebSocketTunnelServer.stop(); - authenticatedGrpcServer.stop(); - anonymousGrpcServer.stop(); - } - - @AfterAll - static void tearDownAfterAll() throws InterruptedException { - nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); - defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); - - delegatedTaskExecutor.shutdown(); - //noinspection ResultOfMethodCallIgnored - delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); - - serverCallExecutor.shutdown(); - //noinspection ResultOfMethodCallIgnored - serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = authenticated() - .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - } finally { - channel.shutdown(); - } - } - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient - .Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) - .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID) - .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void connectAuthenticatedBadServerKeySignature() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebSocketCloseListener(webSocketCloseListener) - .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); - } - - @Test - void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); - - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); - } - - @Test - void connectAuthenticatedUnrecognizedDevice() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); - } - - @Test - void connectAuthenticatedToAnonymousService() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI) - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); - } - - @Test - void connectAnonymous() throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = anonymous().build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertTrue(response.getAccountIdentifier().isEmpty()); - assertEquals(0, response.getDeviceId()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void connectAnonymousBadServerKeySignature() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseWebSocketTunnelClient client = anonymous() - .setWebSocketCloseListener(webSocketCloseListener) - .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); - } - - @Test - void connectAnonymousToAuthenticatedService() throws InterruptedException { - final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - - try (final NoiseWebSocketTunnelClient client = anonymous() - .setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI) - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - } - - verify(webSocketCloseListener).handleWebSocketClosedByServer( - ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); - } - - private ManagedChannel buildManagedChannel(final LocalAddress localAddress) { - return NettyChannelBuilder.forAddress(localAddress) - .channelType(LocalChannel.class) - .eventLoopGroup(defaultEventLoopGroup) - .usePlaintext() - .build(); - } - - @Test - void rejectIllegalRequests() throws Exception { - - final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); - keyStore.load(null, null); - keyStore.setCertificateEntry("tunnel", serverTlsCertificate); - - final TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - - trustManagerFactory.init(keyStore); - - final SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); - - final URI authenticatedUri = - new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null); - - final URI incorrectUri = - new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null); - - try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) { - assertEquals(405, httpClient.send(HttpRequest.newBuilder() - .uri(authenticatedUri) - .PUT(HttpRequest.BodyPublishers.ofString("test")) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "Non-GET requests should not be allowed"); - - assertEquals(426, httpClient.send(HttpRequest.newBuilder() - .GET() - .uri(authenticatedUri) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "GET requests without upgrade headers should not be allowed"); - - assertEquals(404, httpClient.send(HttpRequest.newBuilder() - .GET() - .uri(incorrectUri) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "GET requests to unrecognized URIs should not be allowed"); - } - } - - @Test - void getRequestAttributes() throws InterruptedException { - final String remoteAddress = "4.5.6.7"; - final String acceptLanguage = "en"; - final String userAgent = "Signal-Desktop/1.2.3 Linux"; - - final HttpHeaders headers = new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add("X-Forwarded-For", remoteAddress) - .add("Accept-Language", acceptLanguage) - .add("User-Agent", userAgent); - - try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); - - assertEquals(remoteAddress, response.getRemoteAddress()); - assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); - assertEquals(userAgent, response.getUserAgent()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void closeForReauthentication() throws InterruptedException { - final CountDownLatch connectionCloseLatch = new CountDownLatch(1); - final AtomicInteger serverCloseStatusCode = new AtomicInteger(0); - final AtomicBoolean closedByServer = new AtomicBoolean(false); - - final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() { - - @Override - public void handleWebSocketClosedByClient(final int statusCode) { - serverCloseStatusCode.set(statusCode); - closedByServer.set(false); - connectionCloseLatch.countDown(); - } - - @Override - public void handleWebSocketClosedByServer(final int statusCode) { - serverCloseStatusCode.set(statusCode); - closedByServer.set(true); - connectionCloseLatch.countDown(); - } - }; - - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - - grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); - assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS)); - - assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(), - serverCloseStatusCode.get()); - - assertTrue(closedByServer.get()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void waitForCallCompletion() throws InterruptedException { - final CountDownLatch connectionCloseLatch = new CountDownLatch(1); - final AtomicInteger serverCloseStatusCode = new AtomicInteger(0); - final AtomicBoolean closedByServer = new AtomicBoolean(false); - - final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() { - - @Override - public void handleWebSocketClosedByClient(final int statusCode) { - serverCloseStatusCode.set(statusCode); - closedByServer.set(false); - connectionCloseLatch.countDown(); - } - - @Override - public void handleWebSocketClosedByServer(final int statusCode) { - serverCloseStatusCode.set(statusCode); - closedByServer.set(true); - connectionCloseLatch.countDown(); - } - }; - - try (final NoiseWebSocketTunnelClient client = authenticated() - .setWebSocketCloseListener(webSocketCloseListener) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - - final CountDownLatch responseCountDownLatch = new CountDownLatch(1); - - // Start an open-ended server call and leave it in a non-complete state - final StreamObserver echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream( - new StreamObserver<>() { - @Override - public void onNext(final EchoResponse echoResponse) { - responseCountDownLatch.countDown(); - } - - @Override - public void onError(final Throwable throwable) { - } - - @Override - public void onCompleted() { - } - }); - - // Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before - // the request even starts. Make sure we've done at least one request/response pair to ensure that the call has - // truly started before requesting connection closure. - echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()); - assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS)); - - grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); - assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS), - "Channel should not close until active requests have finished"); - - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel) - .echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build())); - - // Complete the open-ended server call - echoRequestStreamObserver.onCompleted(); - - assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS), - "Channel should close once active requests have finished"); - - assertTrue(closedByServer.get()); - assertEquals(4004, serverCloseStatusCode.get()); - } finally { - channel.shutdown(); - } - } - } - - private NoiseWebSocketTunnelClient.Builder anonymous() { - return new NoiseWebSocketTunnelClient - .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) - .setUseTls(serverTlsCertificate); - - } - - private NoiseWebSocketTunnelClient.Builder authenticated() { - return new NoiseWebSocketTunnelClient - .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) - .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID) - .setUseTls(serverTlsCertificate); - } - - private static Supplier proxyMessageSupplier(boolean includeProxyMesage) { - return includeProxyMesage - ? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, - "10.0.0.1", "10.0.0.2", 12345, 443) - : null; - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java deleted file mode 100644 index 682b22d90..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; - -class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter { - - private final WebSocketCloseListener webSocketCloseListener; - - OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) { - this.webSocketCloseListener = webSocketCloseListener; - } - - @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { - if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { - webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode()); - } - - super.write(context, message, promise); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java deleted file mode 100644 index 6a316d745..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import io.netty.util.ReferenceCountUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline - * after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound - * messages to be bytes. - */ -abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler { - - private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class); - - /** - * Handle an inbound message. The frame will be automatically released after the method is finished running. - * - * @param context The current {@link ChannelHandlerContext} - * @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise - * packet - * @throws Exception - */ - abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception; - - /** - * Handle an outbound byte message. The message will be automatically released after the method is finished running. - * - * @param context The current {@link ChannelHandlerContext} - * @param bytes The bytes to write - * @throws Exception - */ - abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes, - final ChannelPromise promise) throws Exception; - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - try { - if (message instanceof BinaryWebSocketFrame frame) { - handleInbound(context, frame.content()); - } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an - // error - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } finally { - ReferenceCountUtil.release(message); - } - } - - - @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) - throws Exception { - if (message instanceof ByteBuf serverResponse) { - try { - handleOutbound(context, serverResponse, promise); - } finally { - ReferenceCountUtil.release(serverResponse); - } - } else { - if (!(message instanceof WebSocketFrame)) { - // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that - // get issued in response to exceptions) - log.warn("Unexpected object in pipeline: {}", message); - } - context.write(message, promise); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java deleted file mode 100644 index 38253eca7..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -interface WebSocketCloseListener { - - WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() { - @Override - public void handleWebSocketClosedByClient(final int statusCode) { - } - - @Override - public void handleWebSocketClosedByServer(final int statusCode) { - } - }; - - void handleWebSocketClosedByClient(int statusCode); - - void handleWebSocketClosedByServer(int statusCode); -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java new file mode 100644 index 000000000..0b6e2c0ac --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java @@ -0,0 +1,16 @@ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientErrorHandler extends ChannelInboundHandlerAdapter { + private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class); + + @Override + public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { + log.error("Caught inbound error in client; closing connection", cause); + context.channel().close(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java new file mode 100644 index 000000000..4f118b99c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java @@ -0,0 +1,53 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; + +public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) { + + public enum CloseReason { + SERVER_CLOSED, + NOISE_ERROR, + NOISE_HANDSHAKE_ERROR, + AUTHENTICATION_ERROR, + INTERNAL_SERVER_ERROR, + UNKNOWN + } + + public enum CloseInitiator { + SERVER, + CLIENT + } + + public static CloseFrameEvent fromWebsocketCloseFrame( + CloseWebSocketFrame closeWebSocketFrame, + CloseInitiator closeInitiator) { + final CloseReason code = switch (closeWebSocketFrame.statusCode()) { + case 4003 -> CloseReason.NOISE_ERROR; + case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR; + case 4002 -> CloseReason.AUTHENTICATION_ERROR; + case 1011 -> CloseReason.INTERNAL_SERVER_ERROR; + case 1012 -> CloseReason.SERVER_CLOSED; + default -> CloseReason.UNKNOWN; + }; + return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText()); + } + + public static CloseFrameEvent fromNoiseDirectErrorFrame( + NoiseDirectProtos.Error noiseDirectError, + CloseInitiator closeInitiator) { + final CloseReason code = switch (noiseDirectError.getType()) { + case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR; + case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR; + case UNAVAILABLE -> CloseReason.SERVER_CLOSED; + case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR; + case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR; + case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN; + }; + return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java new file mode 100644 index 000000000..c6e11c3cc --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java @@ -0,0 +1,136 @@ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.ReferenceCountUtil; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler; + +/** + * Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote + * gRPC server. + *

+ * This handler waits until the first gRPC client message is ready and then establishes a connection with the remote + * gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote + * connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the + * handshake is finished. + */ +class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { + + private final List remoteHandlerStack; + @Nullable + private final AuthenticatedDevice authenticatedDevice; + + private final SocketAddress remoteServerAddress; + // If provided, will be sent with the payload in the noise handshake + private final byte[] fastOpenRequest; + + private final List pendingReads = new ArrayList<>(); + + private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake"; + + EstablishRemoteConnectionHandler( + final List remoteHandlerStack, + @Nullable final AuthenticatedDevice authenticatedDevice, + final SocketAddress remoteServerAddress, + @Nullable byte[] fastOpenRequest) { + this.remoteHandlerStack = remoteHandlerStack; + this.authenticatedDevice = authenticatedDevice; + this.remoteServerAddress = remoteServerAddress; + this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest; + } + + @Override + public void handlerAdded(final ChannelHandlerContext localContext) { + new Bootstrap() + .channel(NioSocketChannel.class) + .group(localContext.channel().eventLoop()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel channel) throws Exception { + + for (ChannelHandler handler : remoteHandlerStack) { + channel.pipeline().addLast(handler); + } + channel.pipeline() + .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event) + throws Exception { + switch (event) { + case ReadyForNoiseHandshakeEvent ignored -> + remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload())) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + case NoiseClientHandshakeCompleteEvent(Optional fastResponse) -> { + remoteContext.pipeline() + .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); + localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); + + // If there was a payload response on the handshake, write it back to our gRPC client + fastResponse.ifPresent(plaintext -> + localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext))); + + // Forward any messages we got from our gRPC client, now will be proxied to the remote context + pendingReads.forEach(localContext::fireChannelRead); + pendingReads.clear(); + localContext.pipeline().remove(EstablishRemoteConnectionHandler.this); + } + default -> { + } + } + super.userEventTriggered(remoteContext, event); + } + }) + .addLast(new ClientErrorHandler()); + } + }) + .connect(remoteServerAddress) + .addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + // Close the local connection if the remote channel closes and vice versa + future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close()); + localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); + } else { + localContext.close(); + } + }); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + pendingReads.add(message); + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + pendingReads.forEach(ReferenceCountUtil::release); + pendingReads.clear(); + } + + private byte[] initialPayload() { + if (authenticatedDevice == null) { + return fastOpenRequest; + } + + final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length); + bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits()); + bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits()); + bb.put(authenticatedDevice.deviceId()); + bb.put(fastOpenRequest); + bb.flip(); + return bb.array(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java new file mode 100644 index 000000000..dae7c51d6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java @@ -0,0 +1,9 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import io.netty.buffer.ByteBuf; + +public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java similarity index 93% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java index 4bd337009..4237d889f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageSender.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.client; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java similarity index 98% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java index 24e2e84c4..e476ca887 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java @@ -2,7 +2,7 @@ * Copyright 2024 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -12,6 +12,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.util.ReferenceCountUtil; + import java.util.ArrayList; import java.util.Arrays; import java.util.HexFormat; @@ -27,12 +28,12 @@ import java.util.stream.Stream; * Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a * {@link FastOpenRequestBufferedEvent} */ -class Http2Buffering { +public class Http2Buffering { /** * Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request */ - static ChannelInboundHandler handler() { + public static ChannelInboundHandler handler() { return new Http2PrefaceHandler(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java similarity index 89% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java index ab94f1d2f..3174651df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java @@ -2,7 +2,7 @@ * Copyright 2024 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.client; import java.util.Optional; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java new file mode 100644 index 000000000..b9743bf68 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java @@ -0,0 +1,56 @@ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; + +import java.util.Optional; + +public class NoiseClientHandshakeHandler extends ChannelDuplexHandler { + + private final NoiseClientHandshakeHelper handshakeHelper; + + public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) { + this.handshakeHelper = handshakeHelper; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof ByteBuf plaintextHandshakePayload) { + final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload, + plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(), + false); + final byte[] handshakeMessage = handshakeHelper.write(payloadBytes); + ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise); + } else { + ctx.write(msg, promise); + } + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) + throws NoiseHandshakeException { + if (message instanceof ByteBuf frame) { + try { + final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame)); + final Optional fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload); + context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split())); + context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse)); + } finally { + frame.release(); + } + } else { + context.fireChannelRead(message); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + handshakeHelper.destroy(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java similarity index 84% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java index cb3998628..df95bf762 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java @@ -2,7 +2,7 @@ * Copyright 2024 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.client; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.HandshakeState; @@ -11,6 +11,8 @@ import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; public class NoiseClientHandshakeHelper { @@ -22,7 +24,7 @@ public class NoiseClientHandshakeHelper { this.handshakeState = handshakeState; } - static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) { + public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) { try { final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0); @@ -34,7 +36,7 @@ public class NoiseClientHandshakeHelper { } } - static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) { + public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) { try { final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR); state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0); @@ -45,7 +47,7 @@ public class NoiseClientHandshakeHelper { } } - byte[] write(final byte[] requestPayload) throws ShortBufferException { + public byte[] write(final byte[] requestPayload) throws ShortBufferException { final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16]; handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length); return initiateHandshakeMessage; @@ -60,7 +62,7 @@ public class NoiseClientHandshakeHelper { }; } - byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException { + public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException { // Don't process additional messages if the handshake failed and we're just waiting to close if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) { throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction()); @@ -83,11 +85,11 @@ public class NoiseClientHandshakeHelper { } } - CipherStatePair split() { + public CipherStatePair split() { return this.handshakeState.split(); } - void destroy() { + public void destroy() { this.handshakeState.destroy(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java similarity index 77% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java index 57229af83..cae4d82d0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.client; import com.southernstorm.noise.protocol.CipherState; import com.southernstorm.noise.protocol.CipherStatePair; @@ -8,8 +8,6 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -17,7 +15,7 @@ import org.slf4j.LoggerFactory; /** * A Noise transport handler manages a bidirectional Noise session after a handshake has completed. */ -class NoiseClientTransportHandler extends ChannelDuplexHandler { +public class NoiseClientTransportHandler extends ChannelDuplexHandler { private final CipherStatePair cipherStatePair; @@ -30,19 +28,19 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler { @Override public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { try { - if (message instanceof BinaryWebSocketFrame frame) { + if (message instanceof ByteBuf frame) { final CipherState cipherState = cipherStatePair.getReceiver(); // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. // We'll need to copy it to a heap buffer. - final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content()); + final byte[] noiseBuffer = ByteBufUtil.getBytes(frame); // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length); context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength)); } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // Anything except binary frames should have been filtered out of the pipeline by now; treat this as an // error throw new IllegalArgumentException("Unexpected message in pipeline: " + message); } @@ -69,16 +67,13 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler { // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); - context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise); + context.write(Unpooled.wrappedBuffer(noiseBuffer), promise); } finally { ReferenceCountUtil.release(plaintext); } } else { - if (!(message instanceof WebSocketFrame)) { - // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that - // get issued in response to exceptions) - log.warn("Unexpected object in pipeline: {}", message); - } + // Clients only write ByteBufs or close the connection on errors, so any other message is unexpected + log.warn("Unexpected object in pipeline: {}", message); context.write(message, promise); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java new file mode 100644 index 000000000..fddac67fb --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java @@ -0,0 +1,354 @@ +package org.whispersystems.textsecuregcm.grpc.net.client; + +import com.southernstorm.noise.protocol.Noise; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.*; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.MessageToMessageCodec; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaders; +import java.net.SocketAddress; +import java.net.URI; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.function.Supplier; + +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.util.ReferenceCountUtil; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; +import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec; + +import javax.net.ssl.SSLException; + +public class NoiseTunnelClient implements AutoCloseable { + + private final CompletableFuture closeEventFuture; + private final ServerBootstrap serverBootstrap; + private Channel serverChannel; + + public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated"); + public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous"); + + public enum FramingType { + WEBSOCKET, + NOISE_DIRECT + } + + public static class Builder { + + final SocketAddress remoteServerAddress; + NioEventLoopGroup eventLoopGroup; + ECPublicKey serverPublicKey; + + FramingType framingType = FramingType.WEBSOCKET; + URI websocketUri = ANONYMOUS_WEBSOCKET_URI; + HttpHeaders headers = new DefaultHttpHeaders(); + + boolean authenticated = false; + ECKeyPair ecKeyPair = null; + UUID accountIdentifier = null; + byte deviceId = 0x00; + boolean useTls; + X509Certificate trustedServerCertificate = null; + Supplier proxyMessageSupplier = null; + + public Builder( + final SocketAddress remoteServerAddress, + final NioEventLoopGroup eventLoopGroup, + final ECPublicKey serverPublicKey) { + this.remoteServerAddress = remoteServerAddress; + this.eventLoopGroup = eventLoopGroup; + this.serverPublicKey = serverPublicKey; + } + + public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) { + this.authenticated = true; + this.accountIdentifier = accountIdentifier; + this.deviceId = deviceId; + this.ecKeyPair = ecKeyPair; + this.websocketUri = AUTHENTICATED_WEBSOCKET_URI; + return this; + } + + public Builder setWebsocketUri(final URI websocketUri) { + this.websocketUri = websocketUri; + return this; + } + + public Builder setUseTls(X509Certificate trustedServerCertificate) { + this.useTls = true; + this.trustedServerCertificate = trustedServerCertificate; + return this; + } + + public Builder setProxyMessageSupplier(Supplier proxyMessageSupplier) { + this.proxyMessageSupplier = proxyMessageSupplier; + return this; + } + + public Builder setHeaders(final HttpHeaders headers) { + this.headers = headers; + return this; + } + + public Builder setServerPublicKey(ECPublicKey serverPublicKey) { + this.serverPublicKey = serverPublicKey; + return this; + } + + public Builder setFramingType(FramingType framingType) { + this.framingType = framingType; + return this; + } + + public NoiseTunnelClient build() { + final List handlers = new ArrayList<>(); + if (proxyMessageSupplier != null) { + handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier))); + } + if (useTls) { + final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); + + if (trustedServerCertificate != null) { + sslContextBuilder.trustManager(trustedServerCertificate); + } + + try { + handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT)); + } catch (SSLException e) { + throw new IllegalArgumentException(e); + } + } + + // handles the wrapping and unrwrapping the framing layer (websockets or noisedirect) + handlers.addAll(switch (framingType) { + case WEBSOCKET -> websocketHandlerStack(websocketUri, headers); + case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated); + }); + + final NoiseClientHandshakeHelper helper = authenticated + ? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair) + : NoiseClientHandshakeHelper.NK(serverPublicKey); + + handlers.add(new NoiseClientHandshakeHandler(helper)); + + // Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off + // information about why the connection was closed. + final UserEventFuture closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class); + handlers.add(closeEventHandler); + + final NoiseTunnelClient client = + new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler( + handlers, + authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null, + remoteServerAddress, + fastOpenRequest)); + client.start(); + return client; + } + } + + private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup, + CompletableFuture closeEventFuture, + Function handler) { + + this.closeEventFuture = closeEventFuture; + this.serverBootstrap = new ServerBootstrap() + .localAddress(new LocalAddress("websocket-noise-tunnel-client")) + .channel(LocalServerChannel.class) + .group(eventLoopGroup) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(final LocalChannel localChannel) { + localChannel.pipeline() + // We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the + // stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put + // in the handshake. + .addLast(Http2Buffering.handler()) + // Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At + // that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually + // connect to the remote service + .addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) { + byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest()); + requestBufferedEvent.fastOpenRequest().release(); + ctx.pipeline().addLast(handler.apply(fastOpenRequest)); + } + super.userEventTriggered(ctx, evt); + } + }) + .addLast(new ClientErrorHandler()); + } + }); + } + + private static class UserEventFuture extends ChannelInboundHandlerAdapter { + private final CompletableFuture future = new CompletableFuture<>(); + private final Class cls; + + UserEventFuture(Class cls) { + this.cls = cls; + } + + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + if (cls.isInstance(evt)) { + future.complete((T) evt); + } + ctx.fireUserEventTriggered(evt); + } + } + + + public LocalAddress getLocalAddress() { + return (LocalAddress) serverChannel.localAddress(); + } + + private NoiseTunnelClient start() { + serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel(); + return this; + } + + @Override + public void close() throws InterruptedException { + serverChannel.close().await(); + } + + /** + * @return A future that completes when a close frame is observed + */ + public CompletableFuture closeFrameFuture() { + return closeEventFuture; + } + + private static List noiseDirectHandlerStack(boolean authenticated) { + return List.of( + new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2), + new NoiseDirectFrameCodec(), + new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent()); + ctx.fireChannelActive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) { + try { + final NoiseDirectProtos.Error errorPayload = + NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content())); + ctx.fireUserEventTriggered( + CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER)); + } finally { + ReferenceCountUtil.release(msg); + } + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) { + final NoiseDirectProtos.Error errorPayload = + NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content())); + ctx.fireUserEventTriggered( + CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT)); + } + ctx.write(msg, promise); + } + }, + new MessageToMessageCodec() { + boolean noiseHandshakeFinished = false; + + @Override + protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List out) { + final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished + ? NoiseDirectFrame.FrameType.DATA + : (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE); + noiseHandshakeFinished = true; + out.add(new NoiseDirectFrame(frameType, msg.retain())); + } + + @Override + protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg, + final List out) { + out.add(msg.content().retain()); + } + }); + } + + private static List websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) { + return List.of( + new HttpClientCodec(), + new HttpObjectAggregator(Noise.MAX_PACKET_LEN), + // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we + // want to react to them on our own, we need to catch them before they hit that handler. + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { + context.fireUserEventTriggered( + CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER)); + } + + super.channelRead(context, message); + } + }, + new WebSocketClientProtocolHandler(websocketUri, + WebSocketVersion.V13, + null, + false, + headers, + Noise.MAX_PACKET_LEN, + 10_000), + new ChannelOutboundHandlerAdapter() { + @Override + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { + if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { + context.fireUserEventTriggered( + CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT)); + } + super.write(context, message, promise); + } + }, + new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { + if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent()); + } + } + context.fireUserEventTriggered(event); + } + }, + new WebsocketPayloadCodec()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java new file mode 100644 index 000000000..90478599e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java @@ -0,0 +1,4 @@ +package org.whispersystems.textsecuregcm.grpc.net.client; + +public record ReadyForNoiseHandshakeEvent() { +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java new file mode 100644 index 000000000..f2d876c1b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java @@ -0,0 +1,49 @@ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.netty.channel.local.LocalAddress; +import io.netty.channel.nio.NioEventLoopGroup; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; +import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +import java.util.concurrent.Executor; + +class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { + private NoiseDirectTunnelServer noiseDirectTunnelServer; + + @Override + protected void start( + final NioEventLoopGroup eventLoopGroup, + final Executor delegatedTaskExecutor, + final GrpcClientConnectionManager grpcClientConnectionManager, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair serverKeyPair, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress, + final String recognizedProxySecret) throws Exception { + + noiseDirectTunnelServer = new NoiseDirectTunnelServer(0, + eventLoopGroup, + grpcClientConnectionManager, + clientPublicKeysManager, + serverKeyPair, + authenticatedGrpcServerAddress, + anonymousGrpcServerAddress); + noiseDirectTunnelServer.start(); + } + + @Override + protected void stop() throws InterruptedException { + noiseDirectTunnelServer.stop(); + } + + @Override + protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) { + return new NoiseTunnelClient + .Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey) + .setFramingType(NoiseTunnelClient.FramingType.NOISE_DIRECT); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java similarity index 94% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java index da312c719..1765e9c14 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -18,6 +18,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java new file mode 100644 index 000000000..04a68714e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java @@ -0,0 +1,237 @@ +package org.whispersystems.textsecuregcm.grpc.net.websocket; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.grpc.ManagedChannel; +import io.grpc.Status; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; +import java.io.ByteArrayInputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeoutException; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.GetRequestAttributesResponse; +import org.signal.chat.rpc.RequestAttributesGrpc; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { + private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer; + private X509Certificate serverTlsCertificate; + + + // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. + // They were generated with: + // + // ```shell + // openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost" + // ``` + private static final String SERVER_CERTIFICATE = """ + -----BEGIN CERTIFICATE----- + MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw + FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx + MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA + IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV + jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq + SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME + GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG + SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw + XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi + iOr9sHiO8Rn2u0xRKgU5Ig== + -----END CERTIFICATE----- + """; + + // BEGIN/END PRIVATE KEY header/footer removed for easier parsing + private static final String SERVER_PRIVATE_KEY = """ + MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj + kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd + PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA + O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo= + """; + @Override + protected void start( + final NioEventLoopGroup eventLoopGroup, + final Executor delegatedTaskExecutor, + final GrpcClientConnectionManager grpcClientConnectionManager, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair serverKeyPair, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress, + final String recognizedProxySecret) throws Exception { + final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( + new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8))); + final PrivateKey serverTlsPrivateKey; + final KeyFactory keyFactory = KeyFactory.getInstance("EC"); + serverTlsPrivateKey = + keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY))); + tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, + new X509Certificate[]{serverTlsCertificate}, + serverTlsPrivateKey, + eventLoopGroup, + delegatedTaskExecutor, + grpcClientConnectionManager, + clientPublicKeysManager, + serverKeyPair, + authenticatedGrpcServerAddress, + anonymousGrpcServerAddress, + recognizedProxySecret); + tlsNoiseWebSocketTunnelServer.start(); + } + + @Override + protected void stop() throws InterruptedException { + tlsNoiseWebSocketTunnelServer.stop(); + } + + @Override + protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, + final ECPublicKey serverPublicKey) { + return new NoiseTunnelClient + .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey) + .setUseTls(serverTlsCertificate); + } + + @Test + void getRequestAttributes() throws InterruptedException { + final String remoteAddress = "4.5.6.7"; + final String acceptLanguage = "en"; + final String userAgent = "Signal-Desktop/1.2.3 Linux"; + + final HttpHeaders headers = new DefaultHttpHeaders() + .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) + .add("X-Forwarded-For", remoteAddress) + .add("Accept-Language", acceptLanguage) + .add("User-Agent", userAgent); + + try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); + + assertEquals(remoteAddress, response.getRemoteAddress()); + assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); + assertEquals(userAgent, response.getUserAgent()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException { + try (final NoiseTunnelClient client = authenticated() + .setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI) + .build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); + } + } + + + @Test + void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException { + try (final NoiseTunnelClient client = anonymous() + .setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI) + .build()) { + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> RequestAttributesGrpc.newBlockingStub(channel) + .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); + } + } + + @Test + void rejectIllegalRequests() throws Exception { + + final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); + keyStore.setCertificateEntry("tunnel", serverTlsCertificate); + + final TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + + trustManagerFactory.init(keyStore); + + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); + + final URI authenticatedUri = + new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated", + null, null); + + final URI incorrectUri = + new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect", + null, null); + + try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) { + assertEquals(405, httpClient.send(HttpRequest.newBuilder() + .uri(authenticatedUri) + .PUT(HttpRequest.BodyPublishers.ofString("test")) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "Non-GET requests should not be allowed"); + + assertEquals(426, httpClient.send(HttpRequest.newBuilder() + .GET() + .uri(authenticatedUri) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "GET requests without upgrade headers should not be allowed"); + + assertEquals(404, httpClient.send(HttpRequest.newBuilder() + .GET() + .uri(incorrectUri) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "GET requests to unrecognized URIs should not be allowed"); + } + } + + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java new file mode 100644 index 000000000..30ef53022 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java @@ -0,0 +1,52 @@ +package org.whispersystems.textsecuregcm.grpc.net.websocket; + +import io.netty.channel.local.LocalAddress; +import io.netty.channel.nio.NioEventLoopGroup; +import java.util.concurrent.Executor; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; +import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { + private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer; + + @Override + protected void start( + final NioEventLoopGroup eventLoopGroup, + final Executor delegatedTaskExecutor, + final GrpcClientConnectionManager grpcClientConnectionManager, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair serverKeyPair, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress, + final String recognizedProxySecret) throws Exception { + plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, + null, + null, + eventLoopGroup, + delegatedTaskExecutor, + grpcClientConnectionManager, + clientPublicKeysManager, + serverKeyPair, + authenticatedGrpcServerAddress, + anonymousGrpcServerAddress, + recognizedProxySecret); + plaintextNoiseWebSocketTunnelServer.start(); + } + + @Override + protected void stop() throws InterruptedException { + plaintextNoiseWebSocketTunnelServer.stop(); + } + + @Override + protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) { + return new NoiseTunnelClient + .Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java similarity index 96% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java index 1dd906bb2..8bd1fad4f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; @@ -19,6 +19,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java similarity index 96% rename from service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java index 54b796ced..97aeef600 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java @@ -1,4 +1,4 @@ -package org.whispersystems.textsecuregcm.grpc.net; +package org.whispersystems.textsecuregcm.grpc.net.websocket; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -34,6 +34,10 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.grpc.RequestAttributes; +import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; +import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {