Add NoiseDirect framing protocol
This commit is contained in:
parent
e285bf1a52
commit
0398e02690
|
@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
|
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
|
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
|
||||||
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
||||||
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
|
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
|
||||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
||||||
*/
|
*/
|
||||||
class ClientAuthenticationException extends Exception {
|
public class ClientAuthenticationException extends NoStackTraceException {
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,62 +1,46 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
|
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
|
||||||
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
|
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
|
||||||
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
|
|
||||||
* connection.
|
|
||||||
*/
|
*/
|
||||||
class ErrorHandler extends ChannelInboundHandlerAdapter {
|
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private boolean websocketHandshakeComplete = false;
|
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||||
|
|
||||||
@Override
|
private static OutboundCloseErrorMessage UNAUTHENTICATED_CLOSE = new OutboundCloseErrorMessage(
|
||||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
OutboundCloseErrorMessage.Code.AUTHENTICATION_ERROR,
|
||||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
"Not authenticated");
|
||||||
setWebsocketHandshakeComplete();
|
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
||||||
}
|
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
||||||
|
"Noise encryption error");
|
||||||
context.fireUserEventTriggered(event);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void setWebsocketHandshakeComplete() {
|
|
||||||
this.websocketHandshakeComplete = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||||
if (websocketHandshakeComplete) {
|
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
|
||||||
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
|
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
||||||
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
|
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
||||||
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
|
e.getMessage());
|
||||||
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
case ClientAuthenticationException ignored -> UNAUTHENTICATED_CLOSE;
|
||||||
case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||||
|
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||||
default -> {
|
default -> {
|
||||||
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
||||||
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
|
yield new OutboundCloseErrorMessage(
|
||||||
|
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
|
||||||
|
cause.getMessage());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
|
context.writeAndFlush(closeMessage)
|
||||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
} else {
|
|
||||||
log.debug("Error occurred before websocket handshake complete", cause);
|
|
||||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
|
||||||
// way; just close the connection instead.
|
|
||||||
context.close();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,6 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.channel.ChannelInitializer;
|
import io.netty.channel.ChannelInitializer;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
||||||
* server.
|
* server.
|
||||||
*/
|
*/
|
||||||
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||||
|
|
||||||
|
@ -79,7 +77,9 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
// Close the local connection if the remote channel closes and vice versa
|
// Close the local connection if the remote channel closes and vice versa
|
||||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||||
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
||||||
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
|
remoteChannelContext.channel()
|
||||||
|
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||||
|
|
||||||
remoteChannelContext.pipeline()
|
remoteChannelContext.pipeline()
|
||||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
||||||
|
|
|
@ -7,7 +7,6 @@ import io.netty.channel.Channel;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
import io.netty.util.AttributeKey;
|
import io.netty.util.AttributeKey;
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -63,6 +62,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
||||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
||||||
|
|
||||||
|
private static OutboundCloseErrorMessage SERVER_CLOSED =
|
||||||
|
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -161,9 +163,7 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void closeRemoteChannel(final Channel channel) {
|
private static void closeRemoteChannel(final Channel channel) {
|
||||||
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
.toWebSocketCloseStatus("Reauthentication required")))
|
|
||||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
|
@ -198,16 +198,16 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
|
* Handles receipt of a handshake message and associates attributes and headers from the handshake
|
||||||
* request with the channel via which the handshake took place.
|
* request with the channel via which the handshake took place.
|
||||||
*
|
*
|
||||||
* @param channel the channel that completed a WebSocket handshake
|
* @param channel the channel where the handshake was initiated
|
||||||
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
||||||
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
||||||
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
||||||
* {@code null}
|
* {@code null}
|
||||||
*/
|
*/
|
||||||
static void handleHandshakeComplete(final Channel channel,
|
public static void handleHandshakeInitiated(final Channel channel,
|
||||||
final InetAddress preferredRemoteAddress,
|
final InetAddress preferredRemoteAddress,
|
||||||
@Nullable final String userAgentHeader,
|
@Nullable final String userAgentHeader,
|
||||||
@Nullable final String acceptLanguageHeader) {
|
@Nullable final String acceptLanguageHeader) {
|
||||||
|
@ -227,11 +227,10 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
|
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
|
||||||
*
|
*
|
||||||
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
|
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
|
||||||
* server
|
* @param remoteChannel the channel from the remote client to the Noise tunnel
|
||||||
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
|
|
||||||
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
||||||
*/
|
*/
|
||||||
void handleConnectionEstablished(final LocalChannel localChannel,
|
void handleConnectionEstablished(final LocalChannel localChannel,
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
enum HandshakePattern {
|
public enum HandshakePattern {
|
||||||
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
||||||
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
||||||
* indicating that initiator connected anonymously.
|
* indicating that initiator connected anonymously.
|
||||||
*/
|
*/
|
||||||
class NoiseAnonymousHandler extends NoiseHandler {
|
public class NoiseAnonymousHandler extends NoiseHandler {
|
||||||
|
|
||||||
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
||||||
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
||||||
|
|
|
@ -32,11 +32,11 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
* <p>
|
* <p>
|
||||||
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
||||||
*/
|
*/
|
||||||
class NoiseAuthenticatedHandler extends NoiseHandler {
|
public class NoiseAuthenticatedHandler extends NoiseHandler {
|
||||||
|
|
||||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
public NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
final ECKeyPair ecKeyPair) {
|
final ECKeyPair ecKeyPair) {
|
||||||
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
||||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
|
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
|
||||||
* format or a general encryption error).
|
* format or a general encryption error).
|
||||||
*/
|
*/
|
||||||
class NoiseException extends Exception {
|
public class NoiseException extends NoStackTraceException {
|
||||||
public NoiseException(final String message) {
|
public NoiseException(final String message) {
|
||||||
super(message);
|
super(message);
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,13 +26,14 @@ import javax.crypto.ShortBufferException;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
||||||
* inbound messages, and encrypts outbound messages
|
* inbound messages, and encrypts outbound messages
|
||||||
*/
|
*/
|
||||||
abstract class NoiseHandler extends ChannelDuplexHandler {
|
public abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||||
|
|
||||||
|
@ -82,17 +83,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
try {
|
try {
|
||||||
if (message instanceof BinaryWebSocketFrame frame) {
|
if (message instanceof ByteBuf frame) {
|
||||||
if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||||
final String error = "Invalid noise message length " + frame.content().readableBytes();
|
final String error = "Invalid noise message length " + frame.readableBytes();
|
||||||
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
|
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
|
||||||
}
|
}
|
||||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||||
// We'll need to copy it to a heap buffer.
|
// We'll need to copy it to a heap buffer.
|
||||||
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
|
handleInboundMessage(context, ByteBufUtil.getBytes(frame));
|
||||||
} else {
|
} else {
|
||||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
||||||
// error
|
|
||||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -122,7 +122,7 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
// Now that we've authenticated, write the handshake response
|
// Now that we've authenticated, write the handshake response
|
||||||
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
||||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||||
|
|
||||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||||
|
@ -193,16 +193,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||||
|
|
||||||
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
|
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
|
||||||
}
|
}
|
||||||
pc.finish(promise);
|
pc.finish(promise);
|
||||||
} finally {
|
} finally {
|
||||||
ReferenceCountUtil.release(byteBuf);
|
ReferenceCountUtil.release(byteBuf);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!(message instanceof WebSocketFrame)) {
|
if (!(message instanceof OutboundCloseErrorMessage)) {
|
||||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
|
||||||
// get issued in response to exceptions)
|
// that get issued in response to exceptions)
|
||||||
log.warn("Unexpected object in pipeline: {}", message);
|
log.warn("Unexpected object in pipeline: {}", message);
|
||||||
}
|
}
|
||||||
context.write(message, promise);
|
context.write(message, promise);
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
||||||
* a general encryption error).
|
* a general encryption error).
|
||||||
*/
|
*/
|
||||||
class NoiseHandshakeException extends Exception {
|
public class NoiseHandshakeException extends NoStackTraceException {
|
||||||
|
|
||||||
public NoiseHandshakeException(final String message) {
|
public NoiseHandshakeException(final String message) {
|
||||||
super(message);
|
super(message);
|
||||||
|
|
|
@ -10,4 +10,4 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||||
* type that performs authentication
|
* type that performs authentication
|
||||||
*/
|
*/
|
||||||
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
public record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An error written to the outbound pipeline that indicates the connection should be closed
|
||||||
|
*/
|
||||||
|
public record OutboundCloseErrorMessage(Code code, String message) {
|
||||||
|
public enum Code {
|
||||||
|
/**
|
||||||
|
* The server decided to close the connection. This could be because the server is going away, or it could be
|
||||||
|
* because the credentials for the connected client have been updated.
|
||||||
|
*/
|
||||||
|
SERVER_CLOSED,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* There was a noise decryption error after the noise session was established
|
||||||
|
*/
|
||||||
|
NOISE_ERROR,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* There was an error establishing the noise handshake
|
||||||
|
*/
|
||||||
|
NOISE_HANDSHAKE_ERROR,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The provided credentials were not valid
|
||||||
|
*/
|
||||||
|
AUTHENTICATION_ERROR,
|
||||||
|
|
||||||
|
INTERNAL_SERVER_ERROR
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
/**
|
/**
|
||||||
* A proxy handler writes all data read from one channel to another peer channel.
|
* A proxy handler writes all data read from one channel to another peer channel.
|
||||||
*/
|
*/
|
||||||
class ProxyHandler extends ChannelInboundHandlerAdapter {
|
public class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private final Channel peerChannel;
|
private final Channel peerChannel;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
|
||||||
|
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
|
||||||
|
* <p>
|
||||||
|
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
|
||||||
|
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
|
||||||
|
* handshake response, and then the subsequent messages are all data frame payloads.
|
||||||
|
*/
|
||||||
|
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||||
|
if (msg instanceof NoiseDirectFrame frame) {
|
||||||
|
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
|
||||||
|
ReferenceCountUtil.release(msg);
|
||||||
|
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
|
||||||
|
}
|
||||||
|
ctx.fireChannelRead(frame.content());
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||||
|
if (msg instanceof ByteBuf bb) {
|
||||||
|
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.DefaultByteBufHolder;
|
||||||
|
|
||||||
|
public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||||
|
|
||||||
|
static final byte VERSION = 0x00;
|
||||||
|
|
||||||
|
private final FrameType frameType;
|
||||||
|
|
||||||
|
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
|
||||||
|
super(data);
|
||||||
|
this.frameType = frameType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FrameType frameType() {
|
||||||
|
return frameType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public byte versionedFrameTypeByte() {
|
||||||
|
final byte frameBits = frameType().getFrameBits();
|
||||||
|
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public enum FrameType {
|
||||||
|
/**
|
||||||
|
* The payload is the initiator message or the responder message for a Noise NK handshake. If established, the
|
||||||
|
* session will be unauthenticated.
|
||||||
|
*/
|
||||||
|
NK_HANDSHAKE((byte) 1),
|
||||||
|
/**
|
||||||
|
* The payload is the initiator message or the responder message for a Noise IK handshake. If established, the
|
||||||
|
* session will be authenticated.
|
||||||
|
*/
|
||||||
|
IK_HANDSHAKE((byte) 2),
|
||||||
|
/**
|
||||||
|
* The payload is an encrypted noise packet.
|
||||||
|
*/
|
||||||
|
DATA((byte) 3),
|
||||||
|
/**
|
||||||
|
* A framing layer error occurred. The payload carries error details.
|
||||||
|
*/
|
||||||
|
ERROR((byte) 4);
|
||||||
|
|
||||||
|
private final byte frameType;
|
||||||
|
|
||||||
|
FrameType(byte frameType) {
|
||||||
|
if (frameType != (0x0F & frameType)) {
|
||||||
|
throw new IllegalStateException("Frame type must fit in 4 bits");
|
||||||
|
}
|
||||||
|
this.frameType = frameType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public byte getFrameBits() {
|
||||||
|
return frameType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isHandshake() {
|
||||||
|
return switch (this) {
|
||||||
|
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
||||||
|
case DATA, ERROR -> false;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,90 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
|
||||||
|
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
|
||||||
|
*/
|
||||||
|
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||||
|
if (msg instanceof ByteBuf byteBuf) {
|
||||||
|
try {
|
||||||
|
ctx.fireChannelRead(deserialize(byteBuf));
|
||||||
|
} catch (Exception e) {
|
||||||
|
ReferenceCountUtil.release(byteBuf);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||||
|
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
|
||||||
|
try {
|
||||||
|
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
|
||||||
|
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
|
||||||
|
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
|
||||||
|
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
|
||||||
|
ctx.writeAndFlush(serialized, promise);
|
||||||
|
} finally {
|
||||||
|
ReferenceCountUtil.release(noiseDirectFrame);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ByteBuf serialize(
|
||||||
|
final ChannelHandlerContext ctx,
|
||||||
|
final NoiseDirectFrame noiseDirectFrame) {
|
||||||
|
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||||
|
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1 version/frametype byte, 2 length bytes, content
|
||||||
|
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
|
||||||
|
|
||||||
|
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
|
||||||
|
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
|
||||||
|
byteBuf.writeBytes(noiseDirectFrame.content());
|
||||||
|
return byteBuf;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
|
||||||
|
final byte versionAndFrameByte = byteBuf.readByte();
|
||||||
|
final int version = (versionAndFrameByte & 0xF0) >> 4;
|
||||||
|
if (version != NoiseDirectFrame.VERSION) {
|
||||||
|
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
|
||||||
|
}
|
||||||
|
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
|
||||||
|
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
|
||||||
|
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
||||||
|
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
||||||
|
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
||||||
|
case 4 -> NoiseDirectFrame.FrameType.ERROR;
|
||||||
|
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
||||||
|
};
|
||||||
|
|
||||||
|
final int length = Short.toUnsignedInt(byteBuf.readShort());
|
||||||
|
if (length != byteBuf.readableBytes()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
|
||||||
|
}
|
||||||
|
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Waits for a Handshake {@link NoiseDirectFrame} and then installs a {@link NoiseDirectDataFrameCodec} and
|
||||||
|
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} and removes itself
|
||||||
|
*/
|
||||||
|
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
private final ECKeyPair ecKeyPair;
|
||||||
|
|
||||||
|
public NoiseDirectHandshakeSelector(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
||||||
|
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||||
|
if (msg instanceof NoiseDirectFrame frame) {
|
||||||
|
try {
|
||||||
|
// We've received an inbound handshake frame so we know what kind of NoiseHandler we need (authenticated or
|
||||||
|
// anonymous). We construct it here, and then remember the handshake type so we can annotate our handshake
|
||||||
|
// response with the correct frame type whenever we receive it.
|
||||||
|
final ChannelDuplexHandler noiseHandler = switch (frame.frameType()) {
|
||||||
|
case DATA, ERROR ->
|
||||||
|
throw new NoiseHandshakeException("Invalid frame type for first message " + frame.frameType());
|
||||||
|
case IK_HANDSHAKE -> new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
||||||
|
case NK_HANDSHAKE -> new NoiseAnonymousHandler(ecKeyPair);
|
||||||
|
};
|
||||||
|
if (ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
|
||||||
|
// TODO: Provide connection metadata / headers in handshake payload
|
||||||
|
GrpcClientConnectionManager.handleHandshakeInitiated(ctx.channel(),
|
||||||
|
inetSocketAddress.getAddress(),
|
||||||
|
"NoiseDirect",
|
||||||
|
"");
|
||||||
|
|
||||||
|
} else {
|
||||||
|
throw new IOException("Could not determine remote address");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
||||||
|
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
||||||
|
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
||||||
|
// because we may receive an outbound message from the noiseHandler
|
||||||
|
ctx.pipeline().addAfter(ctx.name(), null, noiseHandler);
|
||||||
|
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
||||||
|
ctx.fireChannelRead(frame.content());
|
||||||
|
} catch (Exception e) {
|
||||||
|
ReferenceCountUtil.release(msg);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,39 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufOutputStream;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
|
||||||
|
* written, the channel is closed
|
||||||
|
*/
|
||||||
|
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
|
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||||
|
final NoiseDirectProtos.Error.Type type = switch (err.code()) {
|
||||||
|
case SERVER_CLOSED -> NoiseDirectProtos.Error.Type.UNAVAILABLE;
|
||||||
|
case NOISE_ERROR -> NoiseDirectProtos.Error.Type.ENCRYPTION_ERROR;
|
||||||
|
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.Error.Type.HANDSHAKE_ERROR;
|
||||||
|
case AUTHENTICATION_ERROR -> NoiseDirectProtos.Error.Type.AUTHENTICATION_ERROR;
|
||||||
|
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.Error.Type.INTERNAL_ERROR;
|
||||||
|
};
|
||||||
|
final NoiseDirectProtos.Error proto = NoiseDirectProtos.Error.newBuilder()
|
||||||
|
.setType(type)
|
||||||
|
.setMessage(err.message())
|
||||||
|
.build();
|
||||||
|
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
||||||
|
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
||||||
|
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.ERROR, byteBuf))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE);
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,90 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import io.netty.channel.socket.ServerSocketChannel;
|
||||||
|
import io.netty.channel.socket.SocketChannel;
|
||||||
|
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||||
|
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
|
||||||
|
* binary framing protocol) and passes it through to a local gRPC server.
|
||||||
|
*/
|
||||||
|
public class NoiseDirectTunnelServer implements Managed {
|
||||||
|
|
||||||
|
private final ServerBootstrap bootstrap;
|
||||||
|
private ServerSocketChannel channel;
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
|
||||||
|
|
||||||
|
public NoiseDirectTunnelServer(final int port,
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress) {
|
||||||
|
|
||||||
|
this.bootstrap = new ServerBootstrap()
|
||||||
|
.group(eventLoopGroup)
|
||||||
|
.channel(NioServerSocketChannel.class)
|
||||||
|
.localAddress(port)
|
||||||
|
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(SocketChannel socketChannel) {
|
||||||
|
socketChannel.pipeline()
|
||||||
|
.addLast(new ProxyProtocolDetectionHandler())
|
||||||
|
.addLast(new HAProxyMessageHandler());
|
||||||
|
|
||||||
|
socketChannel.pipeline()
|
||||||
|
// frame byte followed by a 2-byte length field
|
||||||
|
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
||||||
|
// Parses NoiseDirectFrames from wire bytes and vice versa
|
||||||
|
.addLast(new NoiseDirectFrameCodec())
|
||||||
|
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
||||||
|
.addLast(new NoiseDirectOutboundErrorHandler())
|
||||||
|
// Waits for the handshake to finish and then replaces itself with a NoiseDirectFrameCodec and a
|
||||||
|
// NoiseHandler to handle noise encryption/decryption
|
||||||
|
.addLast(new NoiseDirectHandshakeSelector(clientPublicKeysManager, ecKeyPair))
|
||||||
|
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||||
|
// once the Noise handshake has completed
|
||||||
|
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||||
|
grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||||
|
.addLast(new ErrorHandler());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
public InetSocketAddress getLocalAddress() {
|
||||||
|
return channel.localAddress();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() throws InterruptedException {
|
||||||
|
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() throws InterruptedException {
|
||||||
|
if (channel != null) {
|
||||||
|
channel.close().await();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,11 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
|
||||||
enum ApplicationWebSocketCloseReason {
|
enum ApplicationWebSocketCloseReason {
|
||||||
NOISE_HANDSHAKE_ERROR(4001),
|
NOISE_HANDSHAKE_ERROR(4001),
|
||||||
CLIENT_AUTHENTICATION_ERROR(4002),
|
CLIENT_AUTHENTICATION_ERROR(4002),
|
||||||
NOISE_ENCRYPTION_ERROR(4003),
|
NOISE_ENCRYPTION_ERROR(4003);
|
||||||
REAUTHENTICATION_REQUIRED(4004);
|
|
||||||
|
|
||||||
private final int statusCode;
|
private final int statusCode;
|
||||||
|
|
||||||
|
@ -17,8 +16,4 @@ enum ApplicationWebSocketCloseReason {
|
||||||
public int getStatusCode() {
|
public int getStatusCode() {
|
||||||
return statusCode;
|
return statusCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
|
|
||||||
return new WebSocketCloseStatus(statusCode, reason);
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
@ -28,6 +28,7 @@ import javax.net.ssl.SSLException;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.*;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -103,7 +104,10 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
||||||
// request and passed it down the pipeline
|
// request and passed it down the pipeline
|
||||||
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
|
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
|
||||||
.addLast(new WebSocketServerProtocolHandler("/", true))
|
.addLast(new WebSocketServerProtocolHandler("/", true))
|
||||||
|
// Turn generic OutboundCloseErrorMessages into websocket close frames
|
||||||
|
.addLast(new WebSocketOutboundErrorHandler())
|
||||||
.addLast(new RejectUnsupportedMessagesHandler())
|
.addLast(new RejectUnsupportedMessagesHandler())
|
||||||
|
.addLast(new WebsocketPayloadCodec())
|
||||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||||
// a WebSocket handshake has been completed
|
// a WebSocket handshake has been completed
|
||||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
|
@ -0,0 +1,64 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ClientAuthenticationException;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||||
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
||||||
|
*/
|
||||||
|
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
private boolean websocketHandshakeComplete = false;
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||||
|
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||||
|
setWebsocketHandshakeComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
context.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setWebsocketHandshakeComplete() {
|
||||||
|
this.websocketHandshakeComplete = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
|
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||||
|
if (websocketHandshakeComplete) {
|
||||||
|
final int status = switch (err.code()) {
|
||||||
|
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
||||||
|
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
||||||
|
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
||||||
|
case AUTHENTICATION_ERROR -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode();
|
||||||
|
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
||||||
|
};
|
||||||
|
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
} else {
|
||||||
|
log.debug("Error {} occurred before websocket handshake complete", err);
|
||||||
|
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||||
|
// way; just close the connection instead.
|
||||||
|
ctx.close();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.net.InetAddresses;
|
import com.google.common.net.InetAddresses;
|
||||||
|
@ -20,6 +20,9 @@ import org.apache.commons.lang3.StringUtils;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -74,7 +77,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||||
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
|
GrpcClientConnectionManager.handleHandshakeInitiated(context.channel(),
|
||||||
preferredRemoteAddress,
|
preferredRemoteAddress,
|
||||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
||||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
|
@ -0,0 +1,38 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
|
||||||
|
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
|
||||||
|
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
|
||||||
|
*/
|
||||||
|
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||||
|
if (msg instanceof BinaryWebSocketFrame frame) {
|
||||||
|
ctx.fireChannelRead(frame.content());
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||||
|
if (msg instanceof ByteBuf bb) {
|
||||||
|
ctx.write(new BinaryWebSocketFrame(bb), promise);
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect";
|
||||||
|
option java_outer_classname = "NoiseDirectProtos";
|
||||||
|
|
||||||
|
message Error {
|
||||||
|
enum Type {
|
||||||
|
UNSPECIFIED = 0;
|
||||||
|
HANDSHAKE_ERROR = 1;
|
||||||
|
ENCRYPTION_ERROR = 2;
|
||||||
|
UNAVAILABLE = 3;
|
||||||
|
INTERNAL_ERROR = 4;
|
||||||
|
AUTHENTICATION_ERROR = 5;
|
||||||
|
}
|
||||||
|
Type type = 1;
|
||||||
|
string message = 2;
|
||||||
|
}
|
|
@ -4,7 +4,7 @@ import io.netty.util.ResourceLeakDetector;
|
||||||
import org.junit.jupiter.api.AfterAll;
|
import org.junit.jupiter.api.AfterAll;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
|
||||||
abstract class AbstractLeakDetectionTest {
|
public abstract class AbstractLeakDetectionTest {
|
||||||
|
|
||||||
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
|
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
|
||||||
|
|
||||||
|
|
|
@ -119,15 +119,15 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
* waiting messages in the channel, return null.
|
* waiting messages in the channel, return null.
|
||||||
*/
|
*/
|
||||||
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
|
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
|
||||||
final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
if (responseFrame == null) {
|
if (responseFrame == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16];
|
final byte[] plaintext = new byte[responseFrame.readableBytes() - 16];
|
||||||
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
|
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
|
||||||
ByteBufUtil.getBytes(responseFrame.content()), 0,
|
ByteBufUtil.getBytes(responseFrame), 0,
|
||||||
plaintext, 0,
|
plaintext, 0,
|
||||||
responseFrame.content().readableBytes());
|
responseFrame.readableBytes());
|
||||||
assertEquals(read, plaintext.length);
|
assertEquals(read, plaintext.length);
|
||||||
return plaintext;
|
return plaintext;
|
||||||
}
|
}
|
||||||
|
@ -140,7 +140,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||||
|
|
||||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
|
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await();
|
||||||
|
|
||||||
assertFalse(writeFuture.isSuccess());
|
assertFalse(writeFuture.isSuccess());
|
||||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||||
|
@ -150,18 +150,18 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
||||||
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
|
final ByteBuf[] frames = new ByteBuf[7];
|
||||||
|
|
||||||
for (int i = 0; i < frames.length; i++) {
|
for (int i = 0; i < frames.length; i++) {
|
||||||
final byte[] contentBytes = new byte[17];
|
final byte[] contentBytes = new byte[17];
|
||||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
frames[i] = Unpooled.wrappedBuffer(contentBytes);
|
||||||
|
|
||||||
embeddedChannel.writeOneInbound(frames[i]).await();
|
embeddedChannel.writeOneInbound(frames[i]).await();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (final BinaryWebSocketFrame frame : frames) {
|
for (final ByteBuf frame : frames) {
|
||||||
assertEquals(0, frame.refCnt());
|
assertEquals(0, frame.refCnt());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,11 +169,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleNonWebSocketBinaryFrame() throws Throwable {
|
void handleNonByteBufBinaryFrame() throws Throwable {
|
||||||
final byte[] contentBytes = new byte[17];
|
final byte[] contentBytes = new byte[17];
|
||||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
|
final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
||||||
|
|
||||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
||||||
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
||||||
|
|
||||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
|
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext);
|
||||||
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
||||||
assertEquals(0, ciphertextFrame.refCnt());
|
assertEquals(0, ciphertextFrame.refCnt());
|
||||||
|
|
||||||
|
@ -206,7 +206,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
final byte[] bogusCiphertext = new byte[32];
|
final byte[] bogusCiphertext = new byte[32];
|
||||||
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
||||||
|
|
||||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
|
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext);
|
||||||
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
||||||
|
|
||||||
assertEquals(0, ciphertextFrame.refCnt());
|
assertEquals(0, ciphertextFrame.refCnt());
|
||||||
|
@ -235,11 +235,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
assertTrue(writePlaintextFuture.await().isSuccess());
|
assertTrue(writePlaintextFuture.await().isSuccess());
|
||||||
assertEquals(0, plaintextBuffer.refCnt());
|
assertEquals(0, plaintextBuffer.refCnt());
|
||||||
|
|
||||||
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
assertNotNull(ciphertextFrame);
|
assertNotNull(ciphertextFrame);
|
||||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||||
|
|
||||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
|
||||||
ciphertextFrame.release();
|
ciphertextFrame.release();
|
||||||
|
|
||||||
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
||||||
|
@ -272,10 +272,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
final byte[] decryptedPlaintext = new byte[plaintextLength];
|
final byte[] decryptedPlaintext = new byte[plaintextLength];
|
||||||
int plaintextOffset = 0;
|
int plaintextOffset = 0;
|
||||||
BinaryWebSocketFrame ciphertextFrame;
|
ByteBuf ciphertextFrame;
|
||||||
while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) {
|
while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) {
|
||||||
assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN);
|
assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN);
|
||||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
|
||||||
ciphertextFrame.release();
|
ciphertextFrame.release();
|
||||||
plaintextOffset += clientCipherStatePair.getReceiver()
|
plaintextOffset += clientCipherStatePair.getReceiver()
|
||||||
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
|
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
|
||||||
|
@ -289,7 +289,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
public void writeHugeInboundMessage() throws Throwable {
|
public void writeHugeInboundMessage() throws Throwable {
|
||||||
doHandshake();
|
doHandshake();
|
||||||
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
||||||
embeddedChannel.pipeline().fireChannelRead(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(big)));
|
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
|
||||||
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,426 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
|
import io.grpc.ManagedChannel;
|
||||||
|
import io.grpc.ServerBuilder;
|
||||||
|
import io.grpc.Status;
|
||||||
|
import io.grpc.netty.NettyChannelBuilder;
|
||||||
|
import io.grpc.stub.StreamObserver;
|
||||||
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.TimeoutException;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
import org.apache.commons.lang3.RandomStringUtils;
|
||||||
|
import org.junit.jupiter.api.AfterAll;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.signal.chat.rpc.EchoRequest;
|
||||||
|
import org.signal.chat.rpc.EchoResponse;
|
||||||
|
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||||
|
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||||
|
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||||
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
|
|
||||||
|
public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private static NioEventLoopGroup nioEventLoopGroup;
|
||||||
|
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
||||||
|
private static ExecutorService delegatedTaskExecutor;
|
||||||
|
private static ExecutorService serverCallExecutor;
|
||||||
|
|
||||||
|
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||||
|
private ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
|
private ECKeyPair serverKeyPair;
|
||||||
|
private ECKeyPair clientKeyPair;
|
||||||
|
|
||||||
|
private ManagedLocalGrpcServer authenticatedGrpcServer;
|
||||||
|
private ManagedLocalGrpcServer anonymousGrpcServer;
|
||||||
|
|
||||||
|
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
|
||||||
|
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
||||||
|
|
||||||
|
public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void setUpBeforeAll() {
|
||||||
|
nioEventLoopGroup = new NioEventLoopGroup();
|
||||||
|
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
||||||
|
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
|
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
|
|
||||||
|
clientKeyPair = Curve.generateKeyPair();
|
||||||
|
serverKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
grpcClientConnectionManager = new GrpcClientConnectionManager();
|
||||||
|
|
||||||
|
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||||
|
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
|
||||||
|
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
|
||||||
|
|
||||||
|
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
serverBuilder
|
||||||
|
.executor(serverCallExecutor)
|
||||||
|
.addService(new RequestAttributesServiceImpl())
|
||||||
|
.addService(new EchoServiceImpl())
|
||||||
|
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
|
||||||
|
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||||
|
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
authenticatedGrpcServer.start();
|
||||||
|
|
||||||
|
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
serverBuilder
|
||||||
|
.executor(serverCallExecutor)
|
||||||
|
.addService(new RequestAttributesServiceImpl())
|
||||||
|
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||||
|
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
anonymousGrpcServer.start();
|
||||||
|
this.start(
|
||||||
|
nioEventLoopGroup,
|
||||||
|
delegatedTaskExecutor,
|
||||||
|
grpcClientConnectionManager,
|
||||||
|
clientPublicKeysManager,
|
||||||
|
serverKeyPair,
|
||||||
|
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
|
||||||
|
RECOGNIZED_PROXY_SECRET);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected abstract void start(
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final Executor delegatedTaskExecutor,
|
||||||
|
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair serverKeyPair,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress,
|
||||||
|
final String recognizedProxySecret) throws Exception;
|
||||||
|
protected abstract void stop() throws Exception;
|
||||||
|
protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey);
|
||||||
|
|
||||||
|
public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason)
|
||||||
|
throws ExecutionException, InterruptedException, TimeoutException {
|
||||||
|
final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
|
||||||
|
assertEquals(reason, result.closeReason());
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
authenticatedGrpcServer.stop();
|
||||||
|
anonymousGrpcServer.stop();
|
||||||
|
this.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterAll
|
||||||
|
static void tearDownAfterAll() throws InterruptedException {
|
||||||
|
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||||
|
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||||
|
|
||||||
|
delegatedTaskExecutor.shutdown();
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
serverCallExecutor.shutdown();
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
|
||||||
|
try (final NoiseTunnelClient client = authenticated()
|
||||||
|
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
||||||
|
.build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||||
|
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
|
||||||
|
// Try to verify the server's public key with something other than the key with which it was signed
|
||||||
|
try (final NoiseTunnelClient client = authenticated()
|
||||||
|
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
||||||
|
.build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||||
|
|
||||||
|
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
|
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymous() throws InterruptedException {
|
||||||
|
try (final NoiseTunnelClient client = anonymous().build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||||
|
assertEquals(0, response.getDeviceId());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
|
||||||
|
// Try to verify the server's public key with something other than the key with which it was signed
|
||||||
|
try (final NoiseTunnelClient client = anonymous()
|
||||||
|
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
||||||
|
.build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
|
||||||
|
return NettyChannelBuilder.forAddress(localAddress)
|
||||||
|
.channelType(LocalChannel.class)
|
||||||
|
.eventLoopGroup(defaultEventLoopGroup)
|
||||||
|
.usePlaintext()
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
|
||||||
|
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||||
|
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||||
|
|
||||||
|
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||||
|
final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS);
|
||||||
|
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason());
|
||||||
|
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||||
|
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||||
|
|
||||||
|
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
|
||||||
|
|
||||||
|
// Start an open-ended server call and leave it in a non-complete state
|
||||||
|
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
|
||||||
|
new StreamObserver<>() {
|
||||||
|
@Override
|
||||||
|
public void onNext(final EchoResponse echoResponse) {
|
||||||
|
responseCountDownLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(final Throwable throwable) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onCompleted() {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
|
||||||
|
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
|
||||||
|
// truly started before requesting connection closure.
|
||||||
|
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
|
||||||
|
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
|
||||||
|
|
||||||
|
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||||
|
try {
|
||||||
|
client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS);
|
||||||
|
fail("Channel should not close until active requests have finished");
|
||||||
|
} catch (TimeoutException e) {
|
||||||
|
}
|
||||||
|
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
|
||||||
|
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
|
||||||
|
|
||||||
|
// Complete the open-ended server call
|
||||||
|
echoRequestStreamObserver.onCompleted();
|
||||||
|
|
||||||
|
final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
|
||||||
|
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator());
|
||||||
|
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected NoiseTunnelClient.Builder anonymous() {
|
||||||
|
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected NoiseTunnelClient.Builder authenticated() {
|
||||||
|
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||||
|
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
|
||||||
|
return includeProxyMesage
|
||||||
|
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||||
|
"10.0.0.1", "10.0.0.2", 12345, 443)
|
||||||
|
: null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,18 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
|
||||||
|
|
||||||
class ClientErrorHandler extends ErrorHandler {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
|
||||||
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
|
||||||
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
|
||||||
setWebsocketHandshakeComplete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
super.userEventTriggered(context, event);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,198 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
|
||||||
import io.netty.bootstrap.Bootstrap;
|
|
||||||
import io.netty.buffer.Unpooled;
|
|
||||||
import io.netty.channel.ChannelFutureListener;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
||||||
import io.netty.channel.ChannelInitializer;
|
|
||||||
import io.netty.channel.socket.SocketChannel;
|
|
||||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
|
||||||
import io.netty.handler.codec.http.HttpClientCodec;
|
|
||||||
import io.netty.handler.codec.http.HttpHeaders;
|
|
||||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
|
||||||
import io.netty.handler.ssl.SslContextBuilder;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
|
||||||
import java.net.SocketAddress;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.security.cert.X509Certificate;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.function.Supplier;
|
|
||||||
import javax.annotation.Nullable;
|
|
||||||
import javax.net.ssl.SSLException;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
|
|
||||||
* gRPC server
|
|
||||||
*/
|
|
||||||
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|
||||||
|
|
||||||
private final boolean useTls;
|
|
||||||
@Nullable private final X509Certificate trustedServerCertificate;
|
|
||||||
private final URI websocketUri;
|
|
||||||
private final boolean authenticated;
|
|
||||||
@Nullable private final ECKeyPair ecKeyPair;
|
|
||||||
private final ECPublicKey serverPublicKey;
|
|
||||||
@Nullable private final UUID accountIdentifier;
|
|
||||||
private final byte deviceId;
|
|
||||||
private final HttpHeaders headers;
|
|
||||||
private final SocketAddress remoteServerAddress;
|
|
||||||
private final WebSocketCloseListener webSocketCloseListener;
|
|
||||||
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
|
|
||||||
// If provided, will be sent with the payload in the noise handshake
|
|
||||||
private final byte[] fastOpenRequest;
|
|
||||||
|
|
||||||
private final List<Object> pendingReads = new ArrayList<>();
|
|
||||||
|
|
||||||
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
|
|
||||||
|
|
||||||
EstablishRemoteConnectionHandler(
|
|
||||||
final boolean useTls,
|
|
||||||
@Nullable final X509Certificate trustedServerCertificate,
|
|
||||||
final URI websocketUri,
|
|
||||||
final boolean authenticated,
|
|
||||||
@Nullable final ECKeyPair ecKeyPair,
|
|
||||||
final ECPublicKey serverPublicKey,
|
|
||||||
@Nullable final UUID accountIdentifier,
|
|
||||||
final byte deviceId,
|
|
||||||
final HttpHeaders headers,
|
|
||||||
final SocketAddress remoteServerAddress,
|
|
||||||
final WebSocketCloseListener webSocketCloseListener,
|
|
||||||
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier,
|
|
||||||
@Nullable byte[] fastOpenRequest) {
|
|
||||||
|
|
||||||
this.useTls = useTls;
|
|
||||||
this.trustedServerCertificate = trustedServerCertificate;
|
|
||||||
this.websocketUri = websocketUri;
|
|
||||||
this.authenticated = authenticated;
|
|
||||||
this.ecKeyPair = ecKeyPair;
|
|
||||||
this.serverPublicKey = serverPublicKey;
|
|
||||||
this.accountIdentifier = accountIdentifier;
|
|
||||||
this.deviceId = deviceId;
|
|
||||||
this.headers = headers;
|
|
||||||
this.remoteServerAddress = remoteServerAddress;
|
|
||||||
this.webSocketCloseListener = webSocketCloseListener;
|
|
||||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
|
||||||
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handlerAdded(final ChannelHandlerContext localContext) {
|
|
||||||
new Bootstrap()
|
|
||||||
.channel(NioSocketChannel.class)
|
|
||||||
.group(localContext.channel().eventLoop())
|
|
||||||
.handler(new ChannelInitializer<SocketChannel>() {
|
|
||||||
@Override
|
|
||||||
protected void initChannel(final SocketChannel channel) throws SSLException {
|
|
||||||
|
|
||||||
if (proxyMessageSupplier != null) {
|
|
||||||
// In a production setting, we'd want some mechanism to remove these handlers after the initial message
|
|
||||||
// were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a
|
|
||||||
// pair of inert handlers in the pipeline.
|
|
||||||
channel.pipeline()
|
|
||||||
.addLast(HAProxyMessageEncoder.INSTANCE)
|
|
||||||
.addLast(new HAProxyMessageSender(proxyMessageSupplier));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (useTls) {
|
|
||||||
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
|
|
||||||
|
|
||||||
if (trustedServerCertificate != null) {
|
|
||||||
sslContextBuilder.trustManager(trustedServerCertificate);
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
|
|
||||||
}
|
|
||||||
|
|
||||||
final NoiseClientHandshakeHelper helper = authenticated
|
|
||||||
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
|
|
||||||
: NoiseClientHandshakeHelper.NK(serverPublicKey);
|
|
||||||
|
|
||||||
channel.pipeline()
|
|
||||||
.addLast(new HttpClientCodec())
|
|
||||||
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
|
|
||||||
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
|
|
||||||
// want to react to them on our own, we need to catch them before they hit that handler.
|
|
||||||
.addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
|
||||||
.addLast(new WebSocketClientProtocolHandler(websocketUri,
|
|
||||||
WebSocketVersion.V13,
|
|
||||||
null,
|
|
||||||
false,
|
|
||||||
headers,
|
|
||||||
Noise.MAX_PACKET_LEN,
|
|
||||||
10_000))
|
|
||||||
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
|
||||||
// Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done
|
|
||||||
.addLast(new NoiseClientHandshakeHandler(helper, initialPayload()))
|
|
||||||
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
|
|
||||||
@Override
|
|
||||||
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
|
|
||||||
throws Exception {
|
|
||||||
if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) {
|
|
||||||
remoteContext.pipeline()
|
|
||||||
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
|
||||||
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
|
||||||
|
|
||||||
// If there was a payload response on the handshake, write it back to our gRPC client
|
|
||||||
handshakeCompleteEvent.fastResponse().ifPresent(plaintext ->
|
|
||||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
|
|
||||||
|
|
||||||
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
|
||||||
pendingReads.forEach(localContext::fireChannelRead);
|
|
||||||
pendingReads.clear();
|
|
||||||
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
|
|
||||||
}
|
|
||||||
|
|
||||||
super.userEventTriggered(remoteContext, event);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.addLast(new ClientErrorHandler());
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.connect(remoteServerAddress)
|
|
||||||
.addListener((ChannelFutureListener) future -> {
|
|
||||||
if (future.isSuccess()) {
|
|
||||||
// Close the local connection if the remote channel closes and vice versa
|
|
||||||
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
|
|
||||||
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
|
||||||
} else {
|
|
||||||
localContext.close();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
|
||||||
pendingReads.add(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
|
||||||
pendingReads.forEach(ReferenceCountUtil::release);
|
|
||||||
pendingReads.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
private byte[] initialPayload() {
|
|
||||||
if (!authenticated) {
|
|
||||||
return fastOpenRequest;
|
|
||||||
}
|
|
||||||
|
|
||||||
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
|
|
||||||
bb.putLong(accountIdentifier.getMostSignificantBits());
|
|
||||||
bb.putLong(accountIdentifier.getLeastSignificantBits());
|
|
||||||
bb.put(deviceId);
|
|
||||||
bb.put(fastOpenRequest);
|
|
||||||
bb.flip();
|
|
||||||
return bb.array();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,9 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2024 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
|
||||||
|
|
||||||
record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}
|
|
|
@ -146,14 +146,14 @@ class GrpcClientConnectionManagerTest {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource
|
@MethodSource
|
||||||
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
|
void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress,
|
||||||
final String userAgentHeader,
|
final String userAgentHeader,
|
||||||
final String acceptLanguageHeader,
|
final String acceptLanguageHeader,
|
||||||
final RequestAttributes expectedRequestAttributes) {
|
final RequestAttributes expectedRequestAttributes) {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
||||||
|
|
||||||
GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
|
GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel,
|
||||||
preferredRemoteAddress,
|
preferredRemoteAddress,
|
||||||
userAgentHeader,
|
userAgentHeader,
|
||||||
acceptLanguageHeader);
|
acceptLanguageHeader);
|
||||||
|
@ -162,7 +162,7 @@ class GrpcClientConnectionManagerTest {
|
||||||
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
|
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
|
private static List<Arguments> handleHandshakeInitiatedRequestAttributes() {
|
||||||
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
|
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
|
||||||
|
|
||||||
return List.of(
|
return List.of(
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
|
|
||||||
class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter {
|
|
||||||
|
|
||||||
private final WebSocketCloseListener webSocketCloseListener;
|
|
||||||
|
|
||||||
public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
|
|
||||||
this.webSocketCloseListener = webSocketCloseListener;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
|
||||||
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
|
||||||
webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
super.channelRead(context, message);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -10,9 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.HandshakeState;
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
|
@ -49,22 +49,18 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
assertEquals(
|
assertEquals(
|
||||||
initiateHandshakeMessageLength,
|
initiateHandshakeMessageLength,
|
||||||
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
||||||
|
final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage);
|
||||||
final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame(
|
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess());
|
||||||
Unpooled.wrappedBuffer(initiateHandshakeMessage));
|
assertEquals(0, initiateHandshakeMessageBuf.refCnt());
|
||||||
|
|
||||||
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess());
|
|
||||||
assertEquals(0, initiateHandshakeFrame.refCnt());
|
|
||||||
|
|
||||||
embeddedChannel.runPendingTasks();
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
// Read responder handshake message
|
// Read responder handshake message
|
||||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||||
final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame)
|
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
embeddedChannel.outboundMessages().poll();
|
|
||||||
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
|
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
|
||||||
new byte[responderHandshakeFrame.content().readableBytes()];
|
new byte[responderHandshakeFrame.readableBytes()];
|
||||||
responderHandshakeFrame.content().readBytes(responderHandshakeBytes);
|
responderHandshakeFrame.readBytes(responderHandshakeBytes);
|
||||||
|
|
||||||
// ephemeral key, empty encrypted payload AEAD tag
|
// ephemeral key, empty encrypted payload AEAD tag
|
||||||
final byte[] handshakeResponsePayload = new byte[32 + 16];
|
final byte[] handshakeResponsePayload = new byte[32 + 16];
|
||||||
|
|
|
@ -15,10 +15,10 @@ import static org.mockito.Mockito.when;
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.HandshakeState;
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import io.netty.util.internal.EmptyArrays;
|
import io.netty.util.internal.EmptyArrays;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.security.NoSuchAlgorithmException;
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
@ -34,6 +34,7 @@ import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
|
@ -204,13 +205,12 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||||
|
|
||||||
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(
|
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(
|
||||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
|
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)));
|
||||||
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
|
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
|
||||||
|
|
||||||
// While waiting for the public key, send another message
|
// While waiting for the public key, send another message
|
||||||
final ChannelFuture f = embeddedChannel.writeOneInbound(
|
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
|
||||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await();
|
|
||||||
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
|
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
|
||||||
|
|
||||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||||
|
@ -267,8 +267,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||||
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
||||||
|
|
||||||
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(
|
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage);
|
||||||
Unpooled.wrappedBuffer(initiatorMessage));
|
|
||||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
|
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
|
||||||
assertEquals(0, initiatorMessageFrame.refCnt());
|
assertEquals(0, initiatorMessageFrame.refCnt());
|
||||||
if (!await.isSuccess()) {
|
if (!await.isSuccess()) {
|
||||||
|
@ -286,11 +285,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||||
|
|
||||||
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
|
||||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||||
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
new byte[serverStaticKeyMessageFrame.readableBytes()];
|
||||||
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes);
|
||||||
|
|
||||||
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
|
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
|
||||||
|
|
||||||
|
|
|
@ -1,55 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.buffer.ByteBufUtil;
|
|
||||||
import io.netty.buffer.Unpooled;
|
|
||||||
import io.netty.channel.ChannelFutureListener;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter {
|
|
||||||
|
|
||||||
private final NoiseClientHandshakeHelper handshakeHelper;
|
|
||||||
private final byte[] payload;
|
|
||||||
|
|
||||||
NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) {
|
|
||||||
this.handshakeHelper = handshakeHelper;
|
|
||||||
this.payload = payload;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
|
||||||
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
|
||||||
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
|
||||||
byte[] handshakeMessage = handshakeHelper.write(payload);
|
|
||||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
|
||||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
super.userEventTriggered(context, event);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message)
|
|
||||||
throws NoiseHandshakeException {
|
|
||||||
if (message instanceof BinaryWebSocketFrame frame) {
|
|
||||||
try {
|
|
||||||
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content()));
|
|
||||||
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
|
|
||||||
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
|
||||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
|
|
||||||
} finally {
|
|
||||||
frame.release();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
context.fireChannelRead(message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
|
||||||
handshakeHelper.destroy();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -16,6 +16,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.EnumSource;
|
import org.junit.jupiter.params.provider.EnumSource;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper;
|
||||||
|
|
||||||
|
|
||||||
public class NoiseHandshakeHelperTest {
|
public class NoiseHandshakeHelperTest {
|
||||||
|
|
|
@ -1,160 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.bootstrap.ServerBootstrap;
|
|
||||||
import io.netty.buffer.ByteBufUtil;
|
|
||||||
import io.netty.channel.Channel;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
||||||
import io.netty.channel.ChannelInitializer;
|
|
||||||
import io.netty.channel.local.LocalAddress;
|
|
||||||
import io.netty.channel.local.LocalChannel;
|
|
||||||
import io.netty.channel.local.LocalServerChannel;
|
|
||||||
import io.netty.channel.nio.NioEventLoopGroup;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
|
||||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
|
||||||
import io.netty.handler.codec.http.HttpHeaders;
|
|
||||||
import java.net.SocketAddress;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.security.cert.X509Certificate;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Supplier;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
|
||||||
|
|
||||||
class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|
||||||
|
|
||||||
private final ServerBootstrap serverBootstrap;
|
|
||||||
private Channel serverChannel;
|
|
||||||
|
|
||||||
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
|
|
||||||
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
|
|
||||||
|
|
||||||
static class Builder {
|
|
||||||
|
|
||||||
final SocketAddress remoteServerAddress;
|
|
||||||
NioEventLoopGroup eventLoopGroup;
|
|
||||||
ECPublicKey serverPublicKey;
|
|
||||||
|
|
||||||
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
|
||||||
HttpHeaders headers = new DefaultHttpHeaders();
|
|
||||||
WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER;
|
|
||||||
|
|
||||||
boolean authenticated = false;
|
|
||||||
ECKeyPair ecKeyPair = null;
|
|
||||||
UUID accountIdentifier = null;
|
|
||||||
byte deviceId = 0x00;
|
|
||||||
boolean useTls;
|
|
||||||
X509Certificate trustedServerCertificate = null;
|
|
||||||
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
|
||||||
|
|
||||||
Builder(
|
|
||||||
final SocketAddress remoteServerAddress,
|
|
||||||
final NioEventLoopGroup eventLoopGroup,
|
|
||||||
final ECPublicKey serverPublicKey) {
|
|
||||||
this.remoteServerAddress = remoteServerAddress;
|
|
||||||
this.eventLoopGroup = eventLoopGroup;
|
|
||||||
this.serverPublicKey = serverPublicKey;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
|
||||||
this.authenticated = true;
|
|
||||||
this.accountIdentifier = accountIdentifier;
|
|
||||||
this.deviceId = deviceId;
|
|
||||||
this.ecKeyPair = ecKeyPair;
|
|
||||||
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setWebsocketUri(final URI websocketUri) {
|
|
||||||
this.websocketUri = websocketUri;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setUseTls(X509Certificate trustedServerCertificate) {
|
|
||||||
this.useTls = true;
|
|
||||||
this.trustedServerCertificate = trustedServerCertificate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
|
|
||||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setHeaders(final HttpHeaders headers) {
|
|
||||||
this.headers = headers;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) {
|
|
||||||
this.webSocketCloseListener = webSocketCloseListener;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder setServerPublicKey(ECPublicKey serverPublicKey) {
|
|
||||||
this.serverPublicKey = serverPublicKey;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
NoiseWebSocketTunnelClient build() {
|
|
||||||
final NoiseWebSocketTunnelClient client =
|
|
||||||
new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
|
||||||
useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey,
|
|
||||||
accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier,
|
|
||||||
fastOpenRequest));
|
|
||||||
client.start();
|
|
||||||
return client;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup,
|
|
||||||
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
|
||||||
|
|
||||||
this.serverBootstrap = new ServerBootstrap()
|
|
||||||
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
|
||||||
.channel(LocalServerChannel.class)
|
|
||||||
.group(eventLoopGroup)
|
|
||||||
.childHandler(new ChannelInitializer<LocalChannel>() {
|
|
||||||
@Override
|
|
||||||
protected void initChannel(final LocalChannel localChannel) {
|
|
||||||
localChannel.pipeline()
|
|
||||||
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
|
|
||||||
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
|
|
||||||
// in the handshake.
|
|
||||||
.addLast(Http2Buffering.handler())
|
|
||||||
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
|
|
||||||
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
|
|
||||||
// connect to the remote service
|
|
||||||
.addLast(new ChannelInboundHandlerAdapter() {
|
|
||||||
@Override
|
|
||||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
|
||||||
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
|
|
||||||
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
|
|
||||||
requestBufferedEvent.fastOpenRequest().release();
|
|
||||||
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
|
|
||||||
}
|
|
||||||
super.userEventTriggered(ctx, evt);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.addLast(new ClientErrorHandler());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
LocalAddress getLocalAddress() {
|
|
||||||
return (LocalAddress) serverChannel.localAddress();
|
|
||||||
}
|
|
||||||
|
|
||||||
private NoiseWebSocketTunnelClient start() {
|
|
||||||
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() throws InterruptedException {
|
|
||||||
serverChannel.close().await();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,705 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
|
||||||
import static org.mockito.ArgumentMatchers.anyByte;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
import com.google.protobuf.ByteString;
|
|
||||||
import io.grpc.ManagedChannel;
|
|
||||||
import io.grpc.ServerBuilder;
|
|
||||||
import io.grpc.Status;
|
|
||||||
import io.grpc.netty.NettyChannelBuilder;
|
|
||||||
import io.grpc.stub.StreamObserver;
|
|
||||||
import io.netty.channel.DefaultEventLoopGroup;
|
|
||||||
import io.netty.channel.local.LocalAddress;
|
|
||||||
import io.netty.channel.local.LocalChannel;
|
|
||||||
import io.netty.channel.nio.NioEventLoopGroup;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
|
||||||
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
|
||||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
|
||||||
import io.netty.handler.codec.http.HttpHeaders;
|
|
||||||
import java.io.ByteArrayInputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.net.http.HttpClient;
|
|
||||||
import java.net.http.HttpRequest;
|
|
||||||
import java.net.http.HttpResponse;
|
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
import java.security.KeyFactory;
|
|
||||||
import java.security.KeyStore;
|
|
||||||
import java.security.NoSuchAlgorithmException;
|
|
||||||
import java.security.PrivateKey;
|
|
||||||
import java.security.SecureRandom;
|
|
||||||
import java.security.cert.CertificateException;
|
|
||||||
import java.security.cert.CertificateFactory;
|
|
||||||
import java.security.cert.X509Certificate;
|
|
||||||
import java.security.spec.InvalidKeySpecException;
|
|
||||||
import java.security.spec.PKCS8EncodedKeySpec;
|
|
||||||
import java.util.Base64;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
|
||||||
import java.util.concurrent.CountDownLatch;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.function.Supplier;
|
|
||||||
import javax.net.ssl.SSLContext;
|
|
||||||
import javax.net.ssl.TrustManagerFactory;
|
|
||||||
import org.apache.commons.lang3.RandomStringUtils;
|
|
||||||
import org.junit.jupiter.api.AfterAll;
|
|
||||||
import org.junit.jupiter.api.AfterEach;
|
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
|
||||||
import org.signal.chat.rpc.EchoRequest;
|
|
||||||
import org.signal.chat.rpc.EchoResponse;
|
|
||||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
|
||||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
|
||||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
|
||||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
|
||||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
|
||||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
|
||||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
|
||||||
|
|
||||||
class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
|
|
||||||
|
|
||||||
private static NioEventLoopGroup nioEventLoopGroup;
|
|
||||||
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
|
||||||
private static ExecutorService delegatedTaskExecutor;
|
|
||||||
private static ExecutorService serverCallExecutor;
|
|
||||||
|
|
||||||
private static X509Certificate serverTlsCertificate;
|
|
||||||
|
|
||||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
|
||||||
private ClientPublicKeysManager clientPublicKeysManager;
|
|
||||||
|
|
||||||
private ECKeyPair serverKeyPair;
|
|
||||||
private ECKeyPair clientKeyPair;
|
|
||||||
|
|
||||||
private ManagedLocalGrpcServer authenticatedGrpcServer;
|
|
||||||
private ManagedLocalGrpcServer anonymousGrpcServer;
|
|
||||||
|
|
||||||
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
|
|
||||||
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
|
|
||||||
|
|
||||||
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
|
|
||||||
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
|
||||||
|
|
||||||
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
|
|
||||||
|
|
||||||
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
|
|
||||||
// They were generated with:
|
|
||||||
//
|
|
||||||
// ```shell
|
|
||||||
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
|
|
||||||
// ```
|
|
||||||
private static final String SERVER_CERTIFICATE = """
|
|
||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
|
|
||||||
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
|
|
||||||
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
|
|
||||||
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
|
|
||||||
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
|
|
||||||
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
|
|
||||||
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
|
|
||||||
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
|
|
||||||
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
|
|
||||||
iOr9sHiO8Rn2u0xRKgU5Ig==
|
|
||||||
-----END CERTIFICATE-----
|
|
||||||
""";
|
|
||||||
|
|
||||||
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
|
|
||||||
private static final String SERVER_PRIVATE_KEY = """
|
|
||||||
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
|
|
||||||
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
|
|
||||||
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
|
|
||||||
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
|
|
||||||
""";
|
|
||||||
|
|
||||||
@BeforeAll
|
|
||||||
static void setUpBeforeAll() throws CertificateException {
|
|
||||||
nioEventLoopGroup = new NioEventLoopGroup();
|
|
||||||
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
|
||||||
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
|
||||||
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
|
||||||
|
|
||||||
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
|
||||||
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
|
||||||
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
|
|
||||||
}
|
|
||||||
|
|
||||||
@BeforeEach
|
|
||||||
void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException {
|
|
||||||
|
|
||||||
final PrivateKey serverTlsPrivateKey;
|
|
||||||
{
|
|
||||||
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
|
|
||||||
serverTlsPrivateKey =
|
|
||||||
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
|
|
||||||
}
|
|
||||||
|
|
||||||
clientKeyPair = Curve.generateKeyPair();
|
|
||||||
serverKeyPair = Curve.generateKeyPair();
|
|
||||||
|
|
||||||
grpcClientConnectionManager = new GrpcClientConnectionManager();
|
|
||||||
|
|
||||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
|
||||||
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
|
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
|
||||||
|
|
||||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
|
||||||
|
|
||||||
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
|
|
||||||
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
|
|
||||||
|
|
||||||
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
|
||||||
@Override
|
|
||||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
|
||||||
serverBuilder
|
|
||||||
.executor(serverCallExecutor)
|
|
||||||
.addService(new RequestAttributesServiceImpl())
|
|
||||||
.addService(new EchoServiceImpl())
|
|
||||||
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
|
|
||||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
|
||||||
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
authenticatedGrpcServer.start();
|
|
||||||
|
|
||||||
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
|
||||||
@Override
|
|
||||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
|
||||||
serverBuilder
|
|
||||||
.executor(serverCallExecutor)
|
|
||||||
.addService(new RequestAttributesServiceImpl())
|
|
||||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
|
||||||
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
anonymousGrpcServer.start();
|
|
||||||
|
|
||||||
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
|
||||||
new X509Certificate[]{serverTlsCertificate},
|
|
||||||
serverTlsPrivateKey,
|
|
||||||
nioEventLoopGroup,
|
|
||||||
delegatedTaskExecutor,
|
|
||||||
grpcClientConnectionManager,
|
|
||||||
clientPublicKeysManager,
|
|
||||||
serverKeyPair,
|
|
||||||
authenticatedGrpcServerAddress,
|
|
||||||
anonymousGrpcServerAddress,
|
|
||||||
RECOGNIZED_PROXY_SECRET);
|
|
||||||
|
|
||||||
tlsNoiseWebSocketTunnelServer.start();
|
|
||||||
|
|
||||||
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
nioEventLoopGroup,
|
|
||||||
delegatedTaskExecutor,
|
|
||||||
grpcClientConnectionManager,
|
|
||||||
clientPublicKeysManager,
|
|
||||||
serverKeyPair,
|
|
||||||
authenticatedGrpcServerAddress,
|
|
||||||
anonymousGrpcServerAddress,
|
|
||||||
RECOGNIZED_PROXY_SECRET);
|
|
||||||
|
|
||||||
plaintextNoiseWebSocketTunnelServer.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterEach
|
|
||||||
void tearDown() throws InterruptedException {
|
|
||||||
tlsNoiseWebSocketTunnelServer.stop();
|
|
||||||
plaintextNoiseWebSocketTunnelServer.stop();
|
|
||||||
authenticatedGrpcServer.stop();
|
|
||||||
anonymousGrpcServer.stop();
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterAll
|
|
||||||
static void tearDownAfterAll() throws InterruptedException {
|
|
||||||
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
|
||||||
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
|
||||||
|
|
||||||
delegatedTaskExecutor.shutdown();
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
|
||||||
|
|
||||||
serverCallExecutor.shutdown();
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
|
||||||
}
|
|
||||||
|
|
||||||
@ParameterizedTest
|
|
||||||
@ValueSource(booleans = {true, false})
|
|
||||||
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
|
||||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@ParameterizedTest
|
|
||||||
@ValueSource(booleans = {true, false})
|
|
||||||
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
|
|
||||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient
|
|
||||||
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
|
||||||
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
|
|
||||||
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
|
||||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
// Try to verify the server's public key with something other than the key with which it was signed
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
|
||||||
.build()) {
|
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAuthenticatedToAnonymousService() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI)
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAnonymous() throws InterruptedException {
|
|
||||||
try (final NoiseWebSocketTunnelClient client = anonymous().build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
|
||||||
assertEquals(0, response.getDeviceId());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAnonymousBadServerKeySignature() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
// Try to verify the server's public key with something other than the key with which it was signed
|
|
||||||
try (final NoiseWebSocketTunnelClient client = anonymous()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void connectAnonymousToAuthenticatedService() throws InterruptedException {
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = anonymous()
|
|
||||||
.setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
|
||||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
|
||||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
|
|
||||||
return NettyChannelBuilder.forAddress(localAddress)
|
|
||||||
.channelType(LocalChannel.class)
|
|
||||||
.eventLoopGroup(defaultEventLoopGroup)
|
|
||||||
.usePlaintext()
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void rejectIllegalRequests() throws Exception {
|
|
||||||
|
|
||||||
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
|
||||||
keyStore.load(null, null);
|
|
||||||
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
|
|
||||||
|
|
||||||
final TrustManagerFactory trustManagerFactory =
|
|
||||||
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
|
||||||
|
|
||||||
trustManagerFactory.init(keyStore);
|
|
||||||
|
|
||||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
|
||||||
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
|
|
||||||
|
|
||||||
final URI authenticatedUri =
|
|
||||||
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
|
|
||||||
|
|
||||||
final URI incorrectUri =
|
|
||||||
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
|
|
||||||
|
|
||||||
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
|
|
||||||
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
|
|
||||||
.uri(authenticatedUri)
|
|
||||||
.PUT(HttpRequest.BodyPublishers.ofString("test"))
|
|
||||||
.build(),
|
|
||||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
|
||||||
"Non-GET requests should not be allowed");
|
|
||||||
|
|
||||||
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
|
|
||||||
.GET()
|
|
||||||
.uri(authenticatedUri)
|
|
||||||
.build(),
|
|
||||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
|
||||||
"GET requests without upgrade headers should not be allowed");
|
|
||||||
|
|
||||||
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
|
|
||||||
.GET()
|
|
||||||
.uri(incorrectUri)
|
|
||||||
.build(),
|
|
||||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
|
||||||
"GET requests to unrecognized URIs should not be allowed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void getRequestAttributes() throws InterruptedException {
|
|
||||||
final String remoteAddress = "4.5.6.7";
|
|
||||||
final String acceptLanguage = "en";
|
|
||||||
final String userAgent = "Signal-Desktop/1.2.3 Linux";
|
|
||||||
|
|
||||||
final HttpHeaders headers = new DefaultHttpHeaders()
|
|
||||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
|
||||||
.add("X-Forwarded-For", remoteAddress)
|
|
||||||
.add("Accept-Language", acceptLanguage)
|
|
||||||
.add("User-Agent", userAgent);
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) {
|
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertEquals(remoteAddress, response.getRemoteAddress());
|
|
||||||
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
|
|
||||||
assertEquals(userAgent, response.getUserAgent());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void closeForReauthentication() throws InterruptedException {
|
|
||||||
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
|
|
||||||
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
|
|
||||||
final AtomicBoolean closedByServer = new AtomicBoolean(false);
|
|
||||||
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByClient(final int statusCode) {
|
|
||||||
serverCloseStatusCode.set(statusCode);
|
|
||||||
closedByServer.set(false);
|
|
||||||
connectionCloseLatch.countDown();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByServer(final int statusCode) {
|
|
||||||
serverCloseStatusCode.set(statusCode);
|
|
||||||
closedByServer.set(true);
|
|
||||||
connectionCloseLatch.countDown();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
|
||||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
|
||||||
|
|
||||||
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
|
||||||
assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS));
|
|
||||||
|
|
||||||
assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(),
|
|
||||||
serverCloseStatusCode.get());
|
|
||||||
|
|
||||||
assertTrue(closedByServer.get());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void waitForCallCompletion() throws InterruptedException {
|
|
||||||
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
|
|
||||||
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
|
|
||||||
final AtomicBoolean closedByServer = new AtomicBoolean(false);
|
|
||||||
|
|
||||||
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByClient(final int statusCode) {
|
|
||||||
serverCloseStatusCode.set(statusCode);
|
|
||||||
closedByServer.set(false);
|
|
||||||
connectionCloseLatch.countDown();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByServer(final int statusCode) {
|
|
||||||
serverCloseStatusCode.set(statusCode);
|
|
||||||
closedByServer.set(true);
|
|
||||||
connectionCloseLatch.countDown();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
|
||||||
.setWebSocketCloseListener(webSocketCloseListener)
|
|
||||||
.build()) {
|
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
|
||||||
|
|
||||||
try {
|
|
||||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
|
||||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
|
||||||
|
|
||||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
|
||||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
|
||||||
|
|
||||||
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
|
|
||||||
|
|
||||||
// Start an open-ended server call and leave it in a non-complete state
|
|
||||||
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
|
|
||||||
new StreamObserver<>() {
|
|
||||||
@Override
|
|
||||||
public void onNext(final EchoResponse echoResponse) {
|
|
||||||
responseCountDownLatch.countDown();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onError(final Throwable throwable) {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onCompleted() {
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
|
|
||||||
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
|
|
||||||
// truly started before requesting connection closure.
|
|
||||||
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
|
|
||||||
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
|
|
||||||
|
|
||||||
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
|
||||||
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
|
|
||||||
"Channel should not close until active requests have finished");
|
|
||||||
|
|
||||||
//noinspection ResultOfMethodCallIgnored
|
|
||||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
|
|
||||||
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
|
|
||||||
|
|
||||||
// Complete the open-ended server call
|
|
||||||
echoRequestStreamObserver.onCompleted();
|
|
||||||
|
|
||||||
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
|
|
||||||
"Channel should close once active requests have finished");
|
|
||||||
|
|
||||||
assertTrue(closedByServer.get());
|
|
||||||
assertEquals(4004, serverCloseStatusCode.get());
|
|
||||||
} finally {
|
|
||||||
channel.shutdown();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private NoiseWebSocketTunnelClient.Builder anonymous() {
|
|
||||||
return new NoiseWebSocketTunnelClient
|
|
||||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
|
||||||
.setUseTls(serverTlsCertificate);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private NoiseWebSocketTunnelClient.Builder authenticated() {
|
|
||||||
return new NoiseWebSocketTunnelClient
|
|
||||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
|
||||||
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
|
|
||||||
.setUseTls(serverTlsCertificate);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
|
|
||||||
return includeProxyMesage
|
|
||||||
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
|
||||||
"10.0.0.1", "10.0.0.2", 12345, 443)
|
|
||||||
: null;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
|
||||||
import io.netty.channel.ChannelPromise;
|
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
|
|
||||||
class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter {
|
|
||||||
|
|
||||||
private final WebSocketCloseListener webSocketCloseListener;
|
|
||||||
|
|
||||||
OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
|
|
||||||
this.webSocketCloseListener = webSocketCloseListener;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
|
||||||
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
|
||||||
webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode());
|
|
||||||
}
|
|
||||||
|
|
||||||
super.write(context, message, promise);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,80 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2024 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.channel.ChannelPromise;
|
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline
|
|
||||||
* after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound
|
|
||||||
* messages to be bytes.
|
|
||||||
*/
|
|
||||||
abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler {
|
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle an inbound message. The frame will be automatically released after the method is finished running.
|
|
||||||
*
|
|
||||||
* @param context The current {@link ChannelHandlerContext}
|
|
||||||
* @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise
|
|
||||||
* packet
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
|
||||||
abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle an outbound byte message. The message will be automatically released after the method is finished running.
|
|
||||||
*
|
|
||||||
* @param context The current {@link ChannelHandlerContext}
|
|
||||||
* @param bytes The bytes to write
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
|
||||||
abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes,
|
|
||||||
final ChannelPromise promise) throws Exception;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
|
||||||
try {
|
|
||||||
if (message instanceof BinaryWebSocketFrame frame) {
|
|
||||||
handleInbound(context, frame.content());
|
|
||||||
} else {
|
|
||||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
|
||||||
// error
|
|
||||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
ReferenceCountUtil.release(message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
|
||||||
throws Exception {
|
|
||||||
if (message instanceof ByteBuf serverResponse) {
|
|
||||||
try {
|
|
||||||
handleOutbound(context, serverResponse, promise);
|
|
||||||
} finally {
|
|
||||||
ReferenceCountUtil.release(serverResponse);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (!(message instanceof WebSocketFrame)) {
|
|
||||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
|
||||||
// get issued in response to exceptions)
|
|
||||||
log.warn("Unexpected object in pipeline: {}", message);
|
|
||||||
}
|
|
||||||
context.write(message, promise);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,18 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
interface WebSocketCloseListener {
|
|
||||||
|
|
||||||
WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() {
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByClient(final int statusCode) {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void handleWebSocketClosedByServer(final int statusCode) {
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void handleWebSocketClosedByClient(int statusCode);
|
|
||||||
|
|
||||||
void handleWebSocketClosedByServer(int statusCode);
|
|
||||||
}
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public class ClientErrorHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||||
|
log.error("Caught inbound error in client; closing connection", cause);
|
||||||
|
context.channel().close();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||||
|
|
||||||
|
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
|
||||||
|
|
||||||
|
public enum CloseReason {
|
||||||
|
SERVER_CLOSED,
|
||||||
|
NOISE_ERROR,
|
||||||
|
NOISE_HANDSHAKE_ERROR,
|
||||||
|
AUTHENTICATION_ERROR,
|
||||||
|
INTERNAL_SERVER_ERROR,
|
||||||
|
UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum CloseInitiator {
|
||||||
|
SERVER,
|
||||||
|
CLIENT
|
||||||
|
}
|
||||||
|
|
||||||
|
public static CloseFrameEvent fromWebsocketCloseFrame(
|
||||||
|
CloseWebSocketFrame closeWebSocketFrame,
|
||||||
|
CloseInitiator closeInitiator) {
|
||||||
|
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
|
||||||
|
case 4003 -> CloseReason.NOISE_ERROR;
|
||||||
|
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||||
|
case 4002 -> CloseReason.AUTHENTICATION_ERROR;
|
||||||
|
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||||
|
case 1012 -> CloseReason.SERVER_CLOSED;
|
||||||
|
default -> CloseReason.UNKNOWN;
|
||||||
|
};
|
||||||
|
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static CloseFrameEvent fromNoiseDirectErrorFrame(
|
||||||
|
NoiseDirectProtos.Error noiseDirectError,
|
||||||
|
CloseInitiator closeInitiator) {
|
||||||
|
final CloseReason code = switch (noiseDirectError.getType()) {
|
||||||
|
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||||
|
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
|
||||||
|
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
|
||||||
|
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||||
|
case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR;
|
||||||
|
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
|
||||||
|
};
|
||||||
|
return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import io.netty.bootstrap.Bootstrap;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.socket.SocketChannel;
|
||||||
|
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
|
||||||
|
* gRPC server.
|
||||||
|
* <p>
|
||||||
|
* This handler waits until the first gRPC client message is ready and then establishes a connection with the remote
|
||||||
|
* gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote
|
||||||
|
* connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the
|
||||||
|
* handshake is finished.
|
||||||
|
*/
|
||||||
|
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final List<ChannelHandler> remoteHandlerStack;
|
||||||
|
@Nullable
|
||||||
|
private final AuthenticatedDevice authenticatedDevice;
|
||||||
|
|
||||||
|
private final SocketAddress remoteServerAddress;
|
||||||
|
// If provided, will be sent with the payload in the noise handshake
|
||||||
|
private final byte[] fastOpenRequest;
|
||||||
|
|
||||||
|
private final List<Object> pendingReads = new ArrayList<>();
|
||||||
|
|
||||||
|
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
|
||||||
|
|
||||||
|
EstablishRemoteConnectionHandler(
|
||||||
|
final List<ChannelHandler> remoteHandlerStack,
|
||||||
|
@Nullable final AuthenticatedDevice authenticatedDevice,
|
||||||
|
final SocketAddress remoteServerAddress,
|
||||||
|
@Nullable byte[] fastOpenRequest) {
|
||||||
|
this.remoteHandlerStack = remoteHandlerStack;
|
||||||
|
this.authenticatedDevice = authenticatedDevice;
|
||||||
|
this.remoteServerAddress = remoteServerAddress;
|
||||||
|
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerAdded(final ChannelHandlerContext localContext) {
|
||||||
|
new Bootstrap()
|
||||||
|
.channel(NioSocketChannel.class)
|
||||||
|
.group(localContext.channel().eventLoop())
|
||||||
|
.handler(new ChannelInitializer<SocketChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(final SocketChannel channel) throws Exception {
|
||||||
|
|
||||||
|
for (ChannelHandler handler : remoteHandlerStack) {
|
||||||
|
channel.pipeline().addLast(handler);
|
||||||
|
}
|
||||||
|
channel.pipeline()
|
||||||
|
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
|
||||||
|
throws Exception {
|
||||||
|
switch (event) {
|
||||||
|
case ReadyForNoiseHandshakeEvent ignored ->
|
||||||
|
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload()))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
case NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) -> {
|
||||||
|
remoteContext.pipeline()
|
||||||
|
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||||
|
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||||
|
|
||||||
|
// If there was a payload response on the handshake, write it back to our gRPC client
|
||||||
|
fastResponse.ifPresent(plaintext ->
|
||||||
|
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
|
||||||
|
|
||||||
|
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
||||||
|
pendingReads.forEach(localContext::fireChannelRead);
|
||||||
|
pendingReads.clear();
|
||||||
|
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
|
||||||
|
}
|
||||||
|
default -> {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
super.userEventTriggered(remoteContext, event);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.addLast(new ClientErrorHandler());
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.connect(remoteServerAddress)
|
||||||
|
.addListener((ChannelFutureListener) future -> {
|
||||||
|
if (future.isSuccess()) {
|
||||||
|
// Close the local connection if the remote channel closes and vice versa
|
||||||
|
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
|
||||||
|
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
||||||
|
} else {
|
||||||
|
localContext.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||||
|
pendingReads.add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
pendingReads.forEach(ReferenceCountUtil::release);
|
||||||
|
pendingReads.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private byte[] initialPayload() {
|
||||||
|
if (authenticatedDevice == null) {
|
||||||
|
return fastOpenRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
|
||||||
|
bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits());
|
||||||
|
bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits());
|
||||||
|
bb.put(authenticatedDevice.deviceId());
|
||||||
|
bb.put(fastOpenRequest);
|
||||||
|
bb.flip();
|
||||||
|
return bb.array();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
|
||||||
|
public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
|
@ -2,7 +2,7 @@
|
||||||
* Copyright 2024 Signal Messenger, LLC
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
|
@ -12,6 +12,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HexFormat;
|
import java.util.HexFormat;
|
||||||
|
@ -27,12 +28,12 @@ import java.util.stream.Stream;
|
||||||
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
|
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
|
||||||
* {@link FastOpenRequestBufferedEvent}
|
* {@link FastOpenRequestBufferedEvent}
|
||||||
*/
|
*/
|
||||||
class Http2Buffering {
|
public class Http2Buffering {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
|
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
|
||||||
*/
|
*/
|
||||||
static ChannelInboundHandler handler() {
|
public static ChannelInboundHandler handler() {
|
||||||
return new Http2PrefaceHandler();
|
return new Http2PrefaceHandler();
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
* Copyright 2024 Signal Messenger, LLC
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
private final NoiseClientHandshakeHelper handshakeHelper;
|
||||||
|
|
||||||
|
public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) {
|
||||||
|
this.handshakeHelper = handshakeHelper;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
|
if (msg instanceof ByteBuf plaintextHandshakePayload) {
|
||||||
|
final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload,
|
||||||
|
plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(),
|
||||||
|
false);
|
||||||
|
final byte[] handshakeMessage = handshakeHelper.write(payloadBytes);
|
||||||
|
ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise);
|
||||||
|
} else {
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||||
|
throws NoiseHandshakeException {
|
||||||
|
if (message instanceof ByteBuf frame) {
|
||||||
|
try {
|
||||||
|
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
|
||||||
|
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
|
||||||
|
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
||||||
|
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
context.fireChannelRead(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
handshakeHelper.destroy();
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,7 +2,7 @@
|
||||||
* Copyright 2024 Signal Messenger, LLC
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.HandshakeState;
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
@ -11,6 +11,8 @@ import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
|
||||||
public class NoiseClientHandshakeHelper {
|
public class NoiseClientHandshakeHelper {
|
||||||
|
|
||||||
|
@ -22,7 +24,7 @@ public class NoiseClientHandshakeHelper {
|
||||||
this.handshakeState = handshakeState;
|
this.handshakeState = handshakeState;
|
||||||
}
|
}
|
||||||
|
|
||||||
static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
|
public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
|
||||||
try {
|
try {
|
||||||
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||||
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
|
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
|
||||||
|
@ -34,7 +36,7 @@ public class NoiseClientHandshakeHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
|
public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
|
||||||
try {
|
try {
|
||||||
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
||||||
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
||||||
|
@ -45,7 +47,7 @@ public class NoiseClientHandshakeHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] write(final byte[] requestPayload) throws ShortBufferException {
|
public byte[] write(final byte[] requestPayload) throws ShortBufferException {
|
||||||
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
|
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
|
||||||
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
|
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
|
||||||
return initiateHandshakeMessage;
|
return initiateHandshakeMessage;
|
||||||
|
@ -60,7 +62,7 @@ public class NoiseClientHandshakeHelper {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
|
public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
|
||||||
// Don't process additional messages if the handshake failed and we're just waiting to close
|
// Don't process additional messages if the handshake failed and we're just waiting to close
|
||||||
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
|
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
|
||||||
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
|
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
|
||||||
|
@ -83,11 +85,11 @@ public class NoiseClientHandshakeHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CipherStatePair split() {
|
public CipherStatePair split() {
|
||||||
return this.handshakeState.split();
|
return this.handshakeState.split();
|
||||||
}
|
}
|
||||||
|
|
||||||
void destroy() {
|
public void destroy() {
|
||||||
this.handshakeState.destroy();
|
this.handshakeState.destroy();
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherState;
|
import com.southernstorm.noise.protocol.CipherState;
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
|
@ -8,8 +8,6 @@ import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
@ -17,7 +15,7 @@ import org.slf4j.LoggerFactory;
|
||||||
/**
|
/**
|
||||||
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
||||||
*/
|
*/
|
||||||
class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
public class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
private final CipherStatePair cipherStatePair;
|
private final CipherStatePair cipherStatePair;
|
||||||
|
|
||||||
|
@ -30,19 +28,19 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
try {
|
try {
|
||||||
if (message instanceof BinaryWebSocketFrame frame) {
|
if (message instanceof ByteBuf frame) {
|
||||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||||
|
|
||||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||||
// We'll need to copy it to a heap buffer.
|
// We'll need to copy it to a heap buffer.
|
||||||
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content());
|
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame);
|
||||||
|
|
||||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||||
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
|
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
|
||||||
|
|
||||||
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
|
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
|
||||||
} else {
|
} else {
|
||||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
// Anything except binary frames should have been filtered out of the pipeline by now; treat this as an
|
||||||
// error
|
// error
|
||||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
}
|
}
|
||||||
|
@ -69,16 +67,13 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||||
|
|
||||||
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
|
context.write(Unpooled.wrappedBuffer(noiseBuffer), promise);
|
||||||
} finally {
|
} finally {
|
||||||
ReferenceCountUtil.release(plaintext);
|
ReferenceCountUtil.release(plaintext);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!(message instanceof WebSocketFrame)) {
|
// Clients only write ByteBufs or close the connection on errors, so any other message is unexpected
|
||||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
|
||||||
// get issued in response to exceptions)
|
|
||||||
log.warn("Unexpected object in pipeline: {}", message);
|
log.warn("Unexpected object in pipeline: {}", message);
|
||||||
}
|
|
||||||
context.write(message, promise);
|
context.write(message, promise);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,354 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.channel.*;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.channel.local.LocalServerChannel;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||||
|
import io.netty.handler.codec.MessageToMessageCodec;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||||
|
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
||||||
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.HttpClientCodec;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||||
|
import io.netty.handler.ssl.SslContextBuilder;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
|
||||||
|
|
||||||
|
import javax.net.ssl.SSLException;
|
||||||
|
|
||||||
|
public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
|
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
|
||||||
|
private final ServerBootstrap serverBootstrap;
|
||||||
|
private Channel serverChannel;
|
||||||
|
|
||||||
|
public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
|
||||||
|
public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
|
||||||
|
|
||||||
|
public enum FramingType {
|
||||||
|
WEBSOCKET,
|
||||||
|
NOISE_DIRECT
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
final SocketAddress remoteServerAddress;
|
||||||
|
NioEventLoopGroup eventLoopGroup;
|
||||||
|
ECPublicKey serverPublicKey;
|
||||||
|
|
||||||
|
FramingType framingType = FramingType.WEBSOCKET;
|
||||||
|
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
||||||
|
HttpHeaders headers = new DefaultHttpHeaders();
|
||||||
|
|
||||||
|
boolean authenticated = false;
|
||||||
|
ECKeyPair ecKeyPair = null;
|
||||||
|
UUID accountIdentifier = null;
|
||||||
|
byte deviceId = 0x00;
|
||||||
|
boolean useTls;
|
||||||
|
X509Certificate trustedServerCertificate = null;
|
||||||
|
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
||||||
|
|
||||||
|
public Builder(
|
||||||
|
final SocketAddress remoteServerAddress,
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final ECPublicKey serverPublicKey) {
|
||||||
|
this.remoteServerAddress = remoteServerAddress;
|
||||||
|
this.eventLoopGroup = eventLoopGroup;
|
||||||
|
this.serverPublicKey = serverPublicKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
||||||
|
this.authenticated = true;
|
||||||
|
this.accountIdentifier = accountIdentifier;
|
||||||
|
this.deviceId = deviceId;
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setWebsocketUri(final URI websocketUri) {
|
||||||
|
this.websocketUri = websocketUri;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setUseTls(X509Certificate trustedServerCertificate) {
|
||||||
|
this.useTls = true;
|
||||||
|
this.trustedServerCertificate = trustedServerCertificate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
|
||||||
|
this.proxyMessageSupplier = proxyMessageSupplier;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setHeaders(final HttpHeaders headers) {
|
||||||
|
this.headers = headers;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setServerPublicKey(ECPublicKey serverPublicKey) {
|
||||||
|
this.serverPublicKey = serverPublicKey;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setFramingType(FramingType framingType) {
|
||||||
|
this.framingType = framingType;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NoiseTunnelClient build() {
|
||||||
|
final List<ChannelHandler> handlers = new ArrayList<>();
|
||||||
|
if (proxyMessageSupplier != null) {
|
||||||
|
handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier)));
|
||||||
|
}
|
||||||
|
if (useTls) {
|
||||||
|
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
|
||||||
|
|
||||||
|
if (trustedServerCertificate != null) {
|
||||||
|
sslContextBuilder.trustManager(trustedServerCertificate);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT));
|
||||||
|
} catch (SSLException e) {
|
||||||
|
throw new IllegalArgumentException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handles the wrapping and unrwrapping the framing layer (websockets or noisedirect)
|
||||||
|
handlers.addAll(switch (framingType) {
|
||||||
|
case WEBSOCKET -> websocketHandlerStack(websocketUri, headers);
|
||||||
|
case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated);
|
||||||
|
});
|
||||||
|
|
||||||
|
final NoiseClientHandshakeHelper helper = authenticated
|
||||||
|
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
|
||||||
|
: NoiseClientHandshakeHelper.NK(serverPublicKey);
|
||||||
|
|
||||||
|
handlers.add(new NoiseClientHandshakeHandler(helper));
|
||||||
|
|
||||||
|
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
|
||||||
|
// information about why the connection was closed.
|
||||||
|
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
|
||||||
|
handlers.add(closeEventHandler);
|
||||||
|
|
||||||
|
final NoiseTunnelClient client =
|
||||||
|
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||||
|
handlers,
|
||||||
|
authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null,
|
||||||
|
remoteServerAddress,
|
||||||
|
fastOpenRequest));
|
||||||
|
client.start();
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
|
||||||
|
CompletableFuture<CloseFrameEvent> closeEventFuture,
|
||||||
|
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
||||||
|
|
||||||
|
this.closeEventFuture = closeEventFuture;
|
||||||
|
this.serverBootstrap = new ServerBootstrap()
|
||||||
|
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||||
|
.channel(LocalServerChannel.class)
|
||||||
|
.group(eventLoopGroup)
|
||||||
|
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(final LocalChannel localChannel) {
|
||||||
|
localChannel.pipeline()
|
||||||
|
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
|
||||||
|
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
|
||||||
|
// in the handshake.
|
||||||
|
.addLast(Http2Buffering.handler())
|
||||||
|
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
|
||||||
|
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
|
||||||
|
// connect to the remote service
|
||||||
|
.addLast(new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||||
|
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
|
||||||
|
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
|
||||||
|
requestBufferedEvent.fastOpenRequest().release();
|
||||||
|
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
|
||||||
|
}
|
||||||
|
super.userEventTriggered(ctx, evt);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.addLast(new ClientErrorHandler());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class UserEventFuture<T> extends ChannelInboundHandlerAdapter {
|
||||||
|
private final CompletableFuture<T> future = new CompletableFuture<>();
|
||||||
|
private final Class<T> cls;
|
||||||
|
|
||||||
|
UserEventFuture(Class<T> cls) {
|
||||||
|
this.cls = cls;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||||
|
if (cls.isInstance(evt)) {
|
||||||
|
future.complete((T) evt);
|
||||||
|
}
|
||||||
|
ctx.fireUserEventTriggered(evt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public LocalAddress getLocalAddress() {
|
||||||
|
return (LocalAddress) serverChannel.localAddress();
|
||||||
|
}
|
||||||
|
|
||||||
|
private NoiseTunnelClient start() {
|
||||||
|
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws InterruptedException {
|
||||||
|
serverChannel.close().await();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return A future that completes when a close frame is observed
|
||||||
|
*/
|
||||||
|
public CompletableFuture<CloseFrameEvent> closeFrameFuture() {
|
||||||
|
return closeEventFuture;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
|
||||||
|
return List.of(
|
||||||
|
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
|
||||||
|
new NoiseDirectFrameCodec(),
|
||||||
|
new ChannelDuplexHandler() {
|
||||||
|
@Override
|
||||||
|
public void channelActive(ChannelHandlerContext ctx) {
|
||||||
|
ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
|
||||||
|
ctx.fireChannelActive();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
|
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
||||||
|
try {
|
||||||
|
final NoiseDirectProtos.Error errorPayload =
|
||||||
|
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||||
|
ctx.fireUserEventTriggered(
|
||||||
|
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER));
|
||||||
|
} finally {
|
||||||
|
ReferenceCountUtil.release(msg);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
|
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
||||||
|
final NoiseDirectProtos.Error errorPayload =
|
||||||
|
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||||
|
ctx.fireUserEventTriggered(
|
||||||
|
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||||
|
}
|
||||||
|
ctx.write(msg, promise);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new MessageToMessageCodec<NoiseDirectFrame, ByteBuf>() {
|
||||||
|
boolean noiseHandshakeFinished = false;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List<Object> out) {
|
||||||
|
final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished
|
||||||
|
? NoiseDirectFrame.FrameType.DATA
|
||||||
|
: (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE);
|
||||||
|
noiseHandshakeFinished = true;
|
||||||
|
out.add(new NoiseDirectFrame(frameType, msg.retain()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg,
|
||||||
|
final List<Object> out) {
|
||||||
|
out.add(msg.content().retain());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<ChannelHandler> websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) {
|
||||||
|
return List.of(
|
||||||
|
new HttpClientCodec(),
|
||||||
|
new HttpObjectAggregator(Noise.MAX_PACKET_LEN),
|
||||||
|
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
|
||||||
|
// want to react to them on our own, we need to catch them before they hit that handler.
|
||||||
|
new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||||
|
context.fireUserEventTriggered(
|
||||||
|
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER));
|
||||||
|
}
|
||||||
|
|
||||||
|
super.channelRead(context, message);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new WebSocketClientProtocolHandler(websocketUri,
|
||||||
|
WebSocketVersion.V13,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
headers,
|
||||||
|
Noise.MAX_PACKET_LEN,
|
||||||
|
10_000),
|
||||||
|
new ChannelOutboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
||||||
|
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||||
|
context.fireUserEventTriggered(
|
||||||
|
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||||
|
}
|
||||||
|
super.write(context, message, promise);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||||
|
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||||
|
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||||
|
context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
context.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new WebsocketPayloadCodec());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
public record ReadyForNoiseHandshakeEvent() {
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
|
||||||
|
class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||||
|
private NoiseDirectTunnelServer noiseDirectTunnelServer;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void start(
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final Executor delegatedTaskExecutor,
|
||||||
|
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair serverKeyPair,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress,
|
||||||
|
final String recognizedProxySecret) throws Exception {
|
||||||
|
|
||||||
|
noiseDirectTunnelServer = new NoiseDirectTunnelServer(0,
|
||||||
|
eventLoopGroup,
|
||||||
|
grpcClientConnectionManager,
|
||||||
|
clientPublicKeysManager,
|
||||||
|
serverKeyPair,
|
||||||
|
authenticatedGrpcServerAddress,
|
||||||
|
anonymousGrpcServerAddress);
|
||||||
|
noiseDirectTunnelServer.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void stop() throws InterruptedException {
|
||||||
|
noiseDirectTunnelServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
|
||||||
|
return new NoiseTunnelClient
|
||||||
|
.Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
|
||||||
|
.setFramingType(NoiseTunnelClient.FramingType.NOISE_DIRECT);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
@ -18,6 +18,7 @@ import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||||
|
|
||||||
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
|
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
|
@ -0,0 +1,237 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import io.grpc.ManagedChannel;
|
||||||
|
import io.grpc.Status;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.security.KeyFactory;
|
||||||
|
import java.security.KeyStore;
|
||||||
|
import java.security.PrivateKey;
|
||||||
|
import java.security.SecureRandom;
|
||||||
|
import java.security.cert.CertificateFactory;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.security.spec.PKCS8EncodedKeySpec;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import java.util.concurrent.TimeoutException;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.TrustManagerFactory;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||||
|
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||||
|
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||||
|
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
|
||||||
|
private X509Certificate serverTlsCertificate;
|
||||||
|
|
||||||
|
|
||||||
|
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
|
||||||
|
// They were generated with:
|
||||||
|
//
|
||||||
|
// ```shell
|
||||||
|
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
|
||||||
|
// ```
|
||||||
|
private static final String SERVER_CERTIFICATE = """
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
|
||||||
|
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
|
||||||
|
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
|
||||||
|
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
|
||||||
|
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
|
||||||
|
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
|
||||||
|
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
|
||||||
|
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
|
||||||
|
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
|
||||||
|
iOr9sHiO8Rn2u0xRKgU5Ig==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
""";
|
||||||
|
|
||||||
|
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
|
||||||
|
private static final String SERVER_PRIVATE_KEY = """
|
||||||
|
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
|
||||||
|
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
|
||||||
|
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
|
||||||
|
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
|
||||||
|
""";
|
||||||
|
@Override
|
||||||
|
protected void start(
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final Executor delegatedTaskExecutor,
|
||||||
|
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair serverKeyPair,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress,
|
||||||
|
final String recognizedProxySecret) throws Exception {
|
||||||
|
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
||||||
|
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
||||||
|
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
|
||||||
|
final PrivateKey serverTlsPrivateKey;
|
||||||
|
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
|
||||||
|
serverTlsPrivateKey =
|
||||||
|
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
|
||||||
|
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
||||||
|
new X509Certificate[]{serverTlsCertificate},
|
||||||
|
serverTlsPrivateKey,
|
||||||
|
eventLoopGroup,
|
||||||
|
delegatedTaskExecutor,
|
||||||
|
grpcClientConnectionManager,
|
||||||
|
clientPublicKeysManager,
|
||||||
|
serverKeyPair,
|
||||||
|
authenticatedGrpcServerAddress,
|
||||||
|
anonymousGrpcServerAddress,
|
||||||
|
recognizedProxySecret);
|
||||||
|
tlsNoiseWebSocketTunnelServer.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void stop() throws InterruptedException {
|
||||||
|
tlsNoiseWebSocketTunnelServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final ECPublicKey serverPublicKey) {
|
||||||
|
return new NoiseTunnelClient
|
||||||
|
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
|
||||||
|
.setUseTls(serverTlsCertificate);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void getRequestAttributes() throws InterruptedException {
|
||||||
|
final String remoteAddress = "4.5.6.7";
|
||||||
|
final String acceptLanguage = "en";
|
||||||
|
final String userAgent = "Signal-Desktop/1.2.3 Linux";
|
||||||
|
|
||||||
|
final HttpHeaders headers = new DefaultHttpHeaders()
|
||||||
|
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||||
|
.add("X-Forwarded-For", remoteAddress)
|
||||||
|
.add("Accept-Language", acceptLanguage)
|
||||||
|
.add("User-Agent", userAgent);
|
||||||
|
|
||||||
|
try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertEquals(remoteAddress, response.getRemoteAddress());
|
||||||
|
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
|
||||||
|
assertEquals(userAgent, response.getUserAgent());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
try (final NoiseTunnelClient client = authenticated()
|
||||||
|
.setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI)
|
||||||
|
.build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException {
|
||||||
|
try (final NoiseTunnelClient client = anonymous()
|
||||||
|
.setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
|
||||||
|
.build()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectIllegalRequests() throws Exception {
|
||||||
|
|
||||||
|
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||||
|
keyStore.load(null, null);
|
||||||
|
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
|
||||||
|
|
||||||
|
final TrustManagerFactory trustManagerFactory =
|
||||||
|
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||||
|
|
||||||
|
trustManagerFactory.init(keyStore);
|
||||||
|
|
||||||
|
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||||
|
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
|
||||||
|
|
||||||
|
final URI authenticatedUri =
|
||||||
|
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated",
|
||||||
|
null, null);
|
||||||
|
|
||||||
|
final URI incorrectUri =
|
||||||
|
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect",
|
||||||
|
null, null);
|
||||||
|
|
||||||
|
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
|
||||||
|
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.uri(authenticatedUri)
|
||||||
|
.PUT(HttpRequest.BodyPublishers.ofString("test"))
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"Non-GET requests should not be allowed");
|
||||||
|
|
||||||
|
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.GET()
|
||||||
|
.uri(authenticatedUri)
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"GET requests without upgrade headers should not be allowed");
|
||||||
|
|
||||||
|
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.GET()
|
||||||
|
.uri(incorrectUri)
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"GET requests to unrecognized URIs should not be allowed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||||
|
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void start(
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final Executor delegatedTaskExecutor,
|
||||||
|
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair serverKeyPair,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress,
|
||||||
|
final String recognizedProxySecret) throws Exception {
|
||||||
|
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
eventLoopGroup,
|
||||||
|
delegatedTaskExecutor,
|
||||||
|
grpcClientConnectionManager,
|
||||||
|
clientPublicKeysManager,
|
||||||
|
serverKeyPair,
|
||||||
|
authenticatedGrpcServerAddress,
|
||||||
|
anonymousGrpcServerAddress,
|
||||||
|
recognizedProxySecret);
|
||||||
|
plaintextNoiseWebSocketTunnelServer.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void stop() throws InterruptedException {
|
||||||
|
plaintextNoiseWebSocketTunnelServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
|
||||||
|
return new NoiseTunnelClient
|
||||||
|
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
@ -19,6 +19,7 @@ import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||||
|
|
||||||
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
@ -34,6 +34,10 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
Loading…
Reference in New Issue