From 542422b7b8a3cec3a641417c88304e9deaa785c3 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Mon, 24 Jun 2024 17:20:42 -0500 Subject: [PATCH] Replace XX/NX handshakes with IK/NK --- service/config/sample.yml | 1 - .../textsecuregcm/WhisperServerService.java | 1 - .../NoiseWebSocketTunnelConfiguration.java | 1 - .../net/AbstractNoiseHandshakeHandler.java | 124 ----- .../textsecuregcm/grpc/net/ErrorHandler.java | 4 +- .../EstablishLocalGrpcConnectionHandler.java | 6 +- .../grpc/net/HandshakePattern.java | 21 + .../grpc/net/NoiseAnonymousHandler.java | 34 ++ .../grpc/net/NoiseAuthenticatedHandler.java | 96 ++++ .../textsecuregcm/grpc/net/NoiseHandler.java | 195 ++++++++ .../grpc/net/NoiseHandshakeCompleteEvent.java | 13 - .../grpc/net/NoiseHandshakeHelper.java | 124 +++++ .../net/NoiseIdentityDeterminedEvent.java | 13 + .../grpc/net/NoiseNXHandshakeHandler.java | 40 -- .../grpc/net/NoiseWebSocketTunnelServer.java | 3 +- .../grpc/net/NoiseXXHandshakeHandler.java | 178 ------- .../WebsocketHandshakeCompleteHandler.java | 7 +- .../grpc/net/AbstractNoiseClientHandler.java | 94 ---- .../grpc/net/AbstractNoiseHandlerTest.java | 257 ++++++++++ .../AbstractNoiseHandshakeHandlerTest.java | 141 ------ .../net/EstablishRemoteConnectionHandler.java | 50 +- .../net/FastOpenRequestBufferedEvent.java | 9 + .../grpc/net/Http2Buffering.java | 184 +++++++ .../grpc/net/NoiseAnonymousHandlerTest.java | 108 +++++ .../net/NoiseAuthenticatedHandlerTest.java | 301 ++++++++++++ .../NoiseClientHandshakeCompleteEvent.java | 15 + .../grpc/net/NoiseClientHandshakeHandler.java | 55 +++ .../grpc/net/NoiseClientHandshakeHelper.java | 93 ++++ .../net/NoiseClientTransportHandler.java} | 35 +- .../grpc/net/NoiseHandshakeHelperTest.java | 66 +++ .../net/NoiseNXClientHandshakeHandler.java | 47 -- .../grpc/net/NoiseNXHandshakeHandlerTest.java | 84 ---- .../grpc/net/NoiseTransportHandlerTest.java | 135 ------ .../grpc/net/NoiseWebSocketTunnelClient.java | 141 ++++-- ...eWebSocketTunnelServerIntegrationTest.java | 203 +++----- .../net/NoiseXXClientHandshakeHandler.java | 89 ---- .../grpc/net/NoiseXXHandshakeHandlerTest.java | 453 ------------------ .../net/TypedNoiseChannelDuplexHandler.java | 80 ++++ ...WebsocketHandshakeCompleteHandlerTest.java | 7 +- .../resources/config/test-secrets-bundle.yml | 4 +- service/src/test/resources/config/test.yml | 1 - 41 files changed, 1902 insertions(+), 1611 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java rename service/src/{main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandler.java => test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java} (79%) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandlerTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java diff --git a/service/config/sample.yml b/service/config/sample.yml index 4800d9caa..2ee21cecc 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -478,7 +478,6 @@ noiseTunnel: tlsKeyStoreEntryAlias: example.com tlsKeyStorePassword: secret://noiseTunnel.tlsKeyStorePassword noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey - noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret externalRequestFilter: diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index ad57324a6..9a4cc18f6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -933,7 +933,6 @@ public class WhisperServerService extends ApplicationThe Noise Protocol Framework - */ -abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter { - - private final ECKeyPair ecKeyPair; - private final byte[] publicKeySignature; - - private final HandshakeState handshakeState; - - private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32; - - /** - * Constructs a new Noise handler with the given static server keys and static public key signature. The static public - * key must be signed by a trusted root private key whose public key is known to and trusted by authenticating - * clients. - * - * @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler - * @param ecKeyPair the static key pair for this server - * @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key - */ - AbstractNoiseHandshakeHandler(final String noiseProtocolName, - final ECKeyPair ecKeyPair, - final byte[] publicKeySignature) { - - this.ecKeyPair = ecKeyPair; - this.publicKeySignature = publicKeySignature; - - try { - this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER); - } catch (final NoSuchAlgorithmException e) { - throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e); - } - } - - protected HandshakeState getHandshakeState() { - return handshakeState; - } - - /** - * Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's - * static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server. - * Clients must not include an additional payload with their ephemeral key message. The server's reply contains its - * static keys along with an Ed25519 signature of its public static key by a trusted root key. - * - * @param context the channel handler context for this message - * @param frame the websocket frame containing the ephemeral key message - * - * @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a - * general Noise encryption error occurred - */ - protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) - throws NoiseHandshakeException { - - if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) { - throw new NoiseHandshakeException("Unexpected ephemeral key message length"); - } - - // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is - // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key - // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private - // key), so we just set the private key here. - handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0); - handshakeState.start(); - - // The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is - // coming off the wire and so will be in a direct buffer that doesn't have a backing array. - final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content()); - frame.content().readBytes(ephemeralKeyMessage); - - try { - handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); - } catch (final ShortBufferException e) { - // This should never happen since we're checking the length of the frame up front - throw new NoiseHandshakeException("Unexpected client payload"); - } catch (final BadPaddingException e) { - // It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag" - // subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an - // ephemeral-key-only message - throw new NoiseHandshakeException("Invalid keys"); - } - - // Send our key material and public key signature back to the client; this buffer will include: - // - // - A 32-byte plaintext ephemeral key - // - A 32-byte encrypted static key - // - A 16-byte AEAD tag for the static key - // - The public key signature payload - // - A 16-byte AEAD tag for the payload - final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16]; - - try { - handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length); - - context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial))) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - } catch (final ShortBufferException e) { - // This should never happen for messages of known length that we control - throw new AssertionError("Key material buffer was too short for message", e); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - handshakeState.destroy(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java index 1828ee689..1ad4e9f59 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java @@ -9,6 +9,7 @@ import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import javax.crypto.BadPaddingException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; /** * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a @@ -38,7 +39,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter { @Override public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { if (websocketHandshakeComplete) { - final WebSocketCloseStatus webSocketCloseStatus = switch (cause) { + final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) { case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage()); case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated"); case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error"); @@ -51,6 +52,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter { context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus)) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } else { + log.debug("Error occurred before websocket handshake complete", cause); // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful // way; just close the connection instead. context.close(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java index c8295a1c7..e06782c9e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -48,12 +48,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { @Override public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { - if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { + if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the // authenticated service. - final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent() + final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent() ? authenticatedGrpcServerAddress : anonymousGrpcServerAddress; @@ -72,7 +72,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { if (localChannelFuture.isSuccess()) { clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), remoteChannelContext.channel(), - noiseHandshakeCompleteEvent.authenticatedDevice()); + noiseIdentityDeterminedEvent.authenticatedDevice()); // Close the local connection if the remote channel closes and vice versa remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java new file mode 100644 index 000000000..d65ee230f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java @@ -0,0 +1,21 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +enum HandshakePattern { + NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"), + IK("Noise_IK_25519_ChaChaPoly_BLAKE2b"); + + private final String protocol; + + public String protocol() { + return protocol; + } + + + HandshakePattern(String protocol) { + this.protocol = protocol; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java new file mode 100644 index 000000000..063a83a87 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java @@ -0,0 +1,34 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +/** + * A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake + * and noise encryption/decryption. + *

+ * A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an + * initial request in their payload. If provided, this allows the server to begin processing the request without an + * initial message delay (fast open). + *

+ * Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent} + * indicating that initiator connected anonymously. + */ +class NoiseAnonymousHandler extends NoiseHandler { + + public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) { + super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair)); + } + + @Override + CompletableFuture handleHandshakePayload(final ChannelHandlerContext context, + final Optional initiatorPublicKey, final ByteBuf handshakePayload) { + return CompletableFuture.completedFuture(new HandshakeResult( + handshakePayload, + Optional.empty() + )); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java new file mode 100644 index 000000000..b6df55e9d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java @@ -0,0 +1,96 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.ReferenceCountUtil; +import java.security.MessageDigest; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; + +/** + * A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake + * and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public + * key is authenticated by the responder. + *

+ * The authenticated handshake requires the initiator to provide a payload with their first handshake message that + * includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an + * initial request in their payload. If provided, this allows the server to begin processing the request without an + * initial message delay (fast open). + *

+ * +-----------------+----------------+------------------------+
+ * |    UUID (16)    |  deviceId (1)  |     request bytes (N)  |
+ * +-----------------+----------------+------------------------+
+ * 
+ *

+ * For a successful handshake, the static key provided in the handshake message must match the server's stored public + * key for the device identified by the provided ACI and deviceId. + *

+ * As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}. + */ +class NoiseAuthenticatedHandler extends NoiseHandler { + + private final ClientPublicKeysManager clientPublicKeysManager; + + NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair) { + super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair)); + this.clientPublicKeysManager = clientPublicKeysManager; + } + + @Override + CompletableFuture handleHandshakePayload( + final ChannelHandlerContext context, + final Optional initiatorPublicKey, + final ByteBuf handshakePayload) throws NoiseHandshakeException { + if (handshakePayload.readableBytes() < 17) { + throw new NoiseHandshakeException("Invalid handshake payload"); + } + + final byte[] publicKeyFromClient = initiatorPublicKey + .orElseThrow(() -> new IllegalStateException("No remote public key")); + + // Advances the read index by 16 bytes + final UUID accountIdentifier = parseUUID(handshakePayload); + + // Advances the read index by 1 byte + final byte deviceId = handshakePayload.readByte(); + + final ByteBuf fastOpenRequest = handshakePayload.slice(); + return clientPublicKeysManager + .findPublicKey(accountIdentifier, deviceId) + .handleAsync((storedPublicKey, throwable) -> { + if (throwable != null) { + ReferenceCountUtil.release(fastOpenRequest); + throw ExceptionUtils.wrap(throwable); + } + final boolean valid = storedPublicKey + .map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes())) + .orElse(false); + if (!valid) { + throw ExceptionUtils.wrap(new ClientAuthenticationException()); + } + return new HandshakeResult( + fastOpenRequest, + Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))); + }, context.executor()); + } + + /** + * Parse a {@link UUID} out of bytes, advancing the readerIdx by 16 + * + * @param bytes The {@link ByteBuf} to read from + * @return The parsed UUID + * @throws NoiseHandshakeException If a UUID could not be parsed from bytes + */ + private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException { + if (bytes.readableBytes() < 16) { + throw new NoiseHandshakeException("Could not parse account identifier"); + } + return new UUID(bytes.readLong(), bytes.readLong()); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java new file mode 100644 index 000000000..c402bce24 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java @@ -0,0 +1,195 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.CipherState; +import com.southernstorm.noise.protocol.CipherStatePair; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; + +/** + * A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts + * inbound messages, and encrypts outbound messages + */ +abstract class NoiseHandler extends ChannelDuplexHandler { + + private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class); + + private enum State { + // Waiting for handshake to complete + HANDSHAKE, + // Can freely exchange encrypted noise messages on an established session + TRANSPORT, + // Finished with error + ERROR + } + + private final NoiseHandshakeHelper handshakeHelper; + + private State state = State.HANDSHAKE; + private CipherStatePair cipherStatePair; + + NoiseHandler(NoiseHandshakeHelper handshakeHelper) { + this.handshakeHelper = handshakeHelper; + } + + /** + * The result of processing an initiator handshake payload + * + * @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an + * empty ByteBuf + * @param authenticatedDevice If present, the successfully authenticated initiator identity + */ + record HandshakeResult(ByteBuf fastOpenRequest, Optional authenticatedDevice) {} + + /** + * Parse and potentially authenticate the initiator handshake message + * + * @param context A {@link ChannelHandlerContext} + * @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used + * @param handshakePayload The handshake payload provided in the initiator message + * @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was + * present in the handshake payload. + * @throws NoiseHandshakeException If the handshake payload was invalid + * @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated + */ + abstract CompletableFuture handleHandshakePayload( + final ChannelHandlerContext context, + final Optional initiatorPublicKey, + final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException; + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + try { + if (message instanceof BinaryWebSocketFrame frame) { + // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. + // We'll need to copy it to a heap buffer. + handleInboundMessage(context, ByteBufUtil.getBytes(frame.content())); + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } catch (Exception e) { + fail(context, e); + } finally { + ReferenceCountUtil.release(message); + } + } + + private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes) + throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException { + switch (state) { + + // Got an initiator handshake message + case HANDSHAKE -> { + final ByteBuf payload = handshakeHelper.read(frameBytes); + handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync( + (result, throwable) -> { + if (state == State.ERROR) { + return; + } + if (throwable != null) { + fail(context, ExceptionUtils.unwrap(throwable)); + return; + } + context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice())); + + // Now that we've authenticated, write the handshake response + byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES); + context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage))) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + + // The handshake is complete. We can start intercepting read/write for noise encryption/decryption + this.state = State.TRANSPORT; + this.cipherStatePair = handshakeHelper.getHandshakeState().split(); + if (result.fastOpenRequest().isReadable()) { + // The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll + // encrypt the response when the server writes back through us + context.fireChannelRead(result.fastOpenRequest()); + } else { + ReferenceCountUtil.release(result.fastOpenRequest()); + } + }, context.executor()); + } + + // Got a client message that should be decrypted and forwarded + case TRANSPORT -> { + final CipherState cipherState = cipherStatePair.getReceiver(); + // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer + final int plaintextLength = cipherState.decryptWithAd(null, + frameBytes, 0, + frameBytes, 0, + frameBytes.length); + + // Forward the decrypted plaintext along + context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength)); + } + + // The session is already in an error state, drop the message + case ERROR -> { + } + } + } + + /** + * Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the + * context + */ + private void fail(final ChannelHandlerContext context, final Throwable cause) { + this.state = State.ERROR; + context.fireExceptionCaught(cause); + } + + @Override + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) + throws Exception { + if (message instanceof ByteBuf plaintext) { + try { + // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames + final CipherState cipherState = cipherStatePair.getSender(); + final int plaintextLength = plaintext.readableBytes(); + + // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the + // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a + // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. + final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; + plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); + + // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer + cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); + + context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise); + + } finally { + ReferenceCountUtil.release(plaintext); + } + } else { + if (!(message instanceof WebSocketFrame)) { + // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that + // get issued in response to exceptions) + log.warn("Unexpected object in pipeline: {}", message); + } + context.write(message, promise); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java deleted file mode 100644 index 5a2f1ae99..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java +++ /dev/null @@ -1,13 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import java.util.Optional; - -/** - * An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process. - * - * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a - * type that performs authentication - */ -record NoiseHandshakeCompleteEvent(Optional authenticatedDevice) { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java new file mode 100644 index 000000000..e47120fb4 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java @@ -0,0 +1,124 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.HandshakeState; +import com.southernstorm.noise.protocol.Noise; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.security.NoSuchAlgorithmException; +import java.util.Optional; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +/** + * Helper for the responder of a 2-message handshake with a pre-shared responder static key + */ +class NoiseHandshakeHelper { + + private final static int AEAD_TAG_LENGTH = 16; + private final static int KEY_LENGTH = 32; + + private final HandshakePattern handshakePattern; + private final ECKeyPair serverStaticKeyPair; + private final HandshakeState handshakeState; + + NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) { + this.handshakePattern = handshakePattern; + this.serverStaticKeyPair = serverStaticKeyPair; + try { + this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e); + } + } + + /** + * Get the length of the initiator's keys + * + * @return length of the handshake message sent by the remote party (the initiator) not including the payload + */ + private int initiatorHandshakeMessageKeyLength() { + return switch (handshakePattern) { + // ephemeral key, static key (encrypted), AEAD tag for static key + case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH; + // ephemeral key only + case NK -> KEY_LENGTH; + }; + } + + HandshakeState getHandshakeState() { + return this.handshakeState; + } + + ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException { + if (handshakeState.getAction() != HandshakeState.NO_ACTION) { + throw new NoiseHandshakeException("Cannot send more data before handshake is complete"); + } + + // Length for an empty payload + final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH; + if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) { + throw new NoiseHandshakeException("Unexpected ephemeral key message length"); + } + + final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH; + + // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is + // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key + // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private + // key), so we just set the private key here. + handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0); + handshakeState.start(); + + int payloadBytesRead; + + try { + payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length, + remoteHandshakeMessage, 0); + } catch (final ShortBufferException e) { + // This should never happen since we're checking the length of the frame up front + throw new NoiseHandshakeException("Unexpected client payload"); + } catch (final BadPaddingException e) { + // We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key + // or payload + throw new NoiseHandshakeException("Invalid keys or payload"); + } + if (payloadBytesRead != payloadLength) { + throw new NoiseHandshakeException( + "Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead); + } + return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead); + } + + byte[] write(byte[] payload) { + if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) { + throw new IllegalStateException("Cannot send data before handshake is complete"); + } + + // Currently only support handshake patterns where the server static key is known + // Send our ephemeral key and the response to the initiator with the encrypted payload + final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH]; + try { + int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length); + if (written != response.length) { + throw new IllegalStateException("Unexpected handshake response length"); + } + return response; + } catch (final ShortBufferException e) { + // This should never happen for messages of known length that we control + throw new IllegalStateException("Key material buffer was too short for message", e); + } + } + + Optional remotePublicKey() { + return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> { + final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()]; + handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0); + return publicKeyFromClient; + }); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java new file mode 100644 index 000000000..57616c0f3 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java @@ -0,0 +1,13 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import java.util.Optional; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; + +/** + * An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is + * connecting anonymously, the identity is empty, otherwise it will be present and already authenticated. + * + * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a + * type that performs authentication + */ +record NoiseIdentityDeterminedEvent(Optional authenticatedDevice) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java deleted file mode 100644 index f59bdcd45..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java +++ /dev/null @@ -1,40 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import java.util.Optional; -import io.netty.util.ReferenceCountUtil; -import org.signal.libsignal.protocol.ecc.ECKeyPair; - -/** - * A Noise NX handler handles the responder side of a Noise NX handshake. - */ -class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler { - - static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b"; - - NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) { - super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof BinaryWebSocketFrame frame) { - try { - handleEphemeralKeyMessage(context, frame); - } finally { - frame.release(); - } - - // All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider - // the handshake complete - context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); - context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseTransportHandler(getHandshakeState().split())); - } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an - // error - ReferenceCountUtil.release(message); - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java index ef0be249c..537a2baff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServer.java @@ -53,7 +53,6 @@ public class NoiseWebSocketTunnelServer implements Managed { final ClientConnectionManager clientConnectionManager, final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair, - final byte[] publicKeySignature, final LocalAddress authenticatedGrpcServerAddress, final LocalAddress anonymousGrpcServerAddress, final String recognizedProxySecret) throws SSLException { @@ -107,7 +106,7 @@ public class NoiseWebSocketTunnelServer implements Managed { .addLast(new RejectUnsupportedMessagesHandler()) // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once // a WebSocket handshake has been completed - .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret)) + .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret)) // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler // once the Noise handshake has completed .addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java deleted file mode 100644 index c8a944791..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java +++ /dev/null @@ -1,178 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBufUtil; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.util.ReferenceCountUtil; -import java.security.MessageDigest; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -/** - * A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send - * identifying information (an account identifier and device ID) as an additional payload when sending its static key - * material. It compares the static public key against the stored public key for the identified device asynchronously, - * buffering traffic from the client until the authentication check completes. - */ -class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler { - - private final ClientPublicKeysManager clientPublicKeysManager; - - private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY; - - private final List pendingInboundFrames = new ArrayList<>(); - - static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b"; - - // When the client sends its static key message, we expect: - // - // - A 32-byte encrypted static public key - // - A 16-byte AEAD tag for the static key - // - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID) - // - A 16-byte AEAD tag for the identity payload - private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81; - - private enum AuthenticationState { - GET_EPHEMERAL_KEY, - GET_STATIC_KEY, - CHECK_PUBLIC_KEY, - ERROR - } - - public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair, - final byte[] publicKeySignature) { - - super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature); - - this.clientPublicKeysManager = clientPublicKeysManager; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof BinaryWebSocketFrame frame) { - try { - switch (authenticationState) { - case GET_EPHEMERAL_KEY -> { - try { - handleEphemeralKeyMessage(context, frame); - authenticationState = AuthenticationState.GET_STATIC_KEY; - } finally { - frame.release(); - } - } - case GET_STATIC_KEY -> { - try { - handleStaticKey(context, frame); - authenticationState = AuthenticationState.CHECK_PUBLIC_KEY; - } finally { - frame.release(); - } - } - case CHECK_PUBLIC_KEY -> { - // Buffer any inbound traffic until we've finished checking the client's public key - pendingInboundFrames.add(frame); - } - case ERROR -> { - // If authentication has failed for any reason, just discard inbound traffic until the channel closes - frame.release(); - } - } - } catch (final ShortBufferException e) { - authenticationState = AuthenticationState.ERROR; - throw new NoiseHandshakeException("Unexpected payload length"); - } catch (final BadPaddingException e) { - authenticationState = AuthenticationState.ERROR; - throw new ClientAuthenticationException(); - } - } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an - // error - ReferenceCountUtil.release(message); - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } - - private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) - throws NoiseHandshakeException, ShortBufferException, BadPaddingException { - - if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) { - throw new NoiseHandshakeException("Unexpected client static key message length"); - } - - final HandshakeState handshakeState = getHandshakeState(); - - // The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct - // buffer into a heap buffer. - final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content()); - - // The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte) - final byte[] payload = new byte[17]; - - final UUID accountIdentifier; - final byte deviceId; - - final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage, - 0, staticKeyAndClientIdentityMessage.length, payload, 0); - - if (payloadBytesRead != 17) { - throw new NoiseHandshakeException("Unexpected identity payload length"); - } - - try { - accountIdentifier = UUIDUtil.fromBytes(payload, 0); - } catch (final IllegalArgumentException e) { - throw new NoiseHandshakeException("Could not parse account identifier"); - } - - deviceId = payload[16]; - - // Verify the identity of the caller by comparing the submitted static public key against the stored public key for - // the identified device - clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId) - .whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> { - final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()]; - handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0); - - if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) { - context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent( - Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)))); - - context.pipeline().addAfter(context.name(), null, new NoiseTransportHandler(handshakeState.split())); - - // Flush any buffered reads - pendingInboundFrames.forEach(context::fireChannelRead); - pendingInboundFrames.clear(); - - context.pipeline().remove(NoiseXXHandshakeHandler.this); - } else { - // We found a key, but it doesn't match what the caller submitted - context.fireExceptionCaught(new ClientAuthenticationException()); - authenticationState = AuthenticationState.ERROR; - } - }, - () -> { - // We couldn't find a key for the identified account/device - context.fireExceptionCaught(new ClientAuthenticationException()); - authenticationState = AuthenticationState.ERROR; - }), - context.executor()); - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - super.handlerRemoved(context); - - pendingInboundFrames.forEach(BinaryWebSocketFrame::release); - pendingInboundFrames.clear(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java index 89d2d7fe8..fbcfc5470 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandler.java @@ -31,7 +31,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { private final ClientPublicKeysManager clientPublicKeysManager; private final ECKeyPair ecKeyPair; - private final byte[] publicKeySignature; private final byte[] recognizedProxySecret; @@ -45,12 +44,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair, - final byte[] publicKeySignature, final String recognizedProxySecret) { this.clientPublicKeysManager = clientPublicKeysManager; this.ecKeyPair = ecKeyPair; - this.publicKeySignature = publicKeySignature; // The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex- // encoded value). We convert it into a byte array here for easier constant-time comparisons via @@ -84,10 +81,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) { case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> - new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature); + new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair); case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> - new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature); + new NoiseAnonymousHandler(ecKeyPair); default -> { // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java deleted file mode 100644 index 63b114af5..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java +++ /dev/null @@ -1,94 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import java.security.NoSuchAlgorithmException; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter { - - private final ECPublicKey rootPublicKey; - - private final HandshakeState handshakeState; - - AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) { - this.rootPublicKey = rootPublicKey; - - try { - handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR); - } catch (final NoSuchAlgorithmException e) { - throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e); - } - } - - protected abstract String getNoiseProtocolName(); - - protected abstract void startHandshake(); - - protected HandshakeState getHandshakeState() { - return handshakeState; - } - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { - if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { - if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { - startHandshake(); - - final byte[] ephemeralKeyMessage = new byte[32]; - handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0); - - context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage))) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - } - } - - super.userEventTriggered(context, event); - } - - protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) - throws NoiseHandshakeException { - - // The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap - // buffer so we can Noise at it. - final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes()); - final byte[] serverPublicKeySignature = new byte[64]; - - try { - frame.content().readBytes(keyMaterialBuffer); - - final int payloadBytesRead = - handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0); - - if (payloadBytesRead != 64) { - throw new NoiseHandshakeException("Unexpected signature length"); - } - } catch (final ShortBufferException e) { - throw new NoiseHandshakeException("Unexpected signature length"); - } catch (final BadPaddingException e) { - throw new NoiseHandshakeException("Invalid keys"); - } finally { - keyMaterialBuffer.release(); - } - - final byte[] serverPublicKey = new byte[32]; - handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); - - if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) { - throw new NoiseHandshakeException("Invalid server public key signature"); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) throws Exception { - handshakeState.destroy(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java new file mode 100644 index 000000000..aa8c2c6e0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java @@ -0,0 +1,257 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.southernstorm.noise.protocol.CipherStatePair; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.Nullable; +import javax.crypto.AEADBadTagException; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { + + protected ECKeyPair serverKeyPair; + + private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler; + + private EmbeddedChannel embeddedChannel; + + private static class PongHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) { + try { + if (msg instanceof ByteBuf bb) { + if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) { + ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes())) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } else { + throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb))); + } + } else { + throw new IllegalArgumentException("Unexpected message type: " + msg); + } + } finally { + ReferenceCountUtil.release(msg); + } + } + } + + private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { + + @Nullable + private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null; + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { + handshakeCompleteEvent = noiseIdentityDeterminedEvent; + context.pipeline().addAfter(context.name(), null, new PongHandler()); + context.pipeline().remove(NoiseHandshakeCompleteHandler.class); + } else { + context.fireUserEventTriggered(event); + } + } + + @Nullable + public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() { + return handshakeCompleteEvent; + } + } + + @BeforeEach + void setUp() { + serverKeyPair = Curve.generateKeyPair(); + noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler(); + embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler); + } + + @AfterEach + void tearDown() { + embeddedChannel.close(); + } + + protected EmbeddedChannel getEmbeddedChannel() { + return embeddedChannel; + } + + @Nullable + protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() { + return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent(); + } + + protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair); + + protected abstract CipherStatePair doHandshake() throws Throwable; + + /** + * Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no + * waiting messages in the channel, return null. + */ + byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException { + final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + if (responseFrame == null) { + return null; + } + final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16]; + final int read = clientCipherPair.getReceiver().decryptWithAd(null, + ByteBufUtil.getBytes(responseFrame.content()), 0, + plaintext, 0, + responseFrame.content().readableBytes()); + assertEquals(read, plaintext.length); + return plaintext; + } + + + @Test + void handleInvalidInitialMessage() throws InterruptedException { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await(); + + assertFalse(writeFuture.isSuccess()); + assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); + assertEquals(0, content.refCnt()); + assertNull(getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException { + final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7]; + + for (int i = 0; i < frames.length; i++) { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); + + embeddedChannel.writeOneInbound(frames[i]).await(); + } + + for (final BinaryWebSocketFrame frame : frames) { + assertEquals(0, frame.refCnt()); + } + + assertNull(getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleNonWebSocketBinaryFrame() throws Throwable { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + final ByteBuf message = Unpooled.wrappedBuffer(contentBytes); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await(); + + assertFalse(writeFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, writeFuture.cause()); + assertEquals(0, message.refCnt()); + assertNull(getNoiseHandshakeCompleteEvent()); + + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } + + @Test + void channelRead() throws Throwable { + final CipherStatePair clientCipherStatePair = doHandshake(); + final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8); + final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()]; + clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length); + + final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext)); + assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess()); + assertEquals(0, ciphertextFrame.refCnt()); + + final byte[] response = readNextPlaintext(clientCipherStatePair); + assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response); + } + + @Test + void channelReadBadCiphertext() throws Throwable { + doHandshake(); + final byte[] bogusCiphertext = new byte[32]; + io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext); + + final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext)); + final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await(); + + assertEquals(0, ciphertextFrame.refCnt()); + assertFalse(readCiphertextFuture.isSuccess()); + assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause()); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } + + @Test + void channelReadUnexpectedMessageType() throws Throwable { + doHandshake(); + final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await(); + + assertFalse(readFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, readFuture.cause()); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } + + @Test + void write() throws Throwable { + final CipherStatePair clientCipherStatePair = doHandshake(); + final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); + final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext); + + final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); + assertTrue(writePlaintextFuture.await().isSuccess()); + assertEquals(0, plaintextBuffer.refCnt()); + + final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + assertNotNull(ciphertextFrame); + assertTrue(embeddedChannel.outboundMessages().isEmpty()); + + final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); + ciphertextFrame.release(); + + final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()]; + clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length); + + assertArrayEquals(plaintext, decryptedPlaintext); + } + + @Test + void writeUnexpectedMessageType() throws Throwable { + doHandshake(); + final Object unexpectedMessaged = new Object(); + + final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged); + assertTrue(writeFuture.await().isSuccess()); + + assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll()); + assertTrue(embeddedChannel.outboundMessages().isEmpty()); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java deleted file mode 100644 index 885b08645..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java +++ /dev/null @@ -1,141 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.util.ReferenceCountUtil; -import java.util.concurrent.ThreadLocalRandom; -import javax.annotation.Nullable; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest { - - private ECPublicKey rootPublicKey; - - private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler; - - private EmbeddedChannel embeddedChannel; - - private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { - - @Nullable - private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null; - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { - handshakeCompleteEvent = noiseHandshakeCompleteEvent; - } else { - context.fireUserEventTriggered(event); - } - } - - @Nullable - public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() { - return handshakeCompleteEvent; - } - } - - @BeforeEach - void setUp() { - final ECKeyPair rootKeyPair = Curve.generateKeyPair(); - final ECKeyPair serverKeyPair = Curve.generateKeyPair(); - - rootPublicKey = rootKeyPair.getPublicKey(); - - final byte[] serverPublicKeySignature = - rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()); - - noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler(); - - embeddedChannel = - new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler); - } - - @AfterEach - void tearDown() { - embeddedChannel.close(); - } - - protected EmbeddedChannel getEmbeddedChannel() { - return embeddedChannel; - } - - protected ECPublicKey getRootPublicKey() { - return rootPublicKey; - } - - @Nullable - protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() { - return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent(); - } - - protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature); - - @Test - void handleInvalidInitialMessage() throws InterruptedException { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await(); - - assertFalse(writeFuture.isSuccess()); - assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); - assertEquals(0, content.refCnt()); - assertNull(getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException { - final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7]; - - for (int i = 0; i < frames.length; i++) { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); - - embeddedChannel.writeOneInbound(frames[i]).await(); - } - - for (final BinaryWebSocketFrame frame : frames) { - assertEquals(0, frame.refCnt()); - } - - assertNull(getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleNonWebSocketBinaryFrame() throws InterruptedException { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - final ByteBuf message = Unpooled.wrappedBuffer(contentBytes); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await(); - - assertFalse(writeFuture.isSuccess()); - assertInstanceOf(IllegalArgumentException.class, writeFuture.cause()); - assertEquals(0, message.refCnt()); - assertNull(getNoiseHandshakeCompleteEvent()); - - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java index b6d1e05c7..66039f7b9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java @@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.grpc.net; import com.southernstorm.noise.protocol.Noise; import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -19,6 +20,7 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.util.ReferenceCountUtil; import java.net.SocketAddress; import java.net.URI; +import java.nio.ByteBuffer; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; @@ -29,6 +31,10 @@ import javax.net.ssl.SSLException; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; +/** + * Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote + * gRPC server + */ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { private final boolean useTls; @@ -36,13 +42,15 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { private final URI websocketUri; private final boolean authenticated; @Nullable private final ECKeyPair ecKeyPair; - private final ECPublicKey rootPublicKey; + private final ECPublicKey serverPublicKey; @Nullable private final UUID accountIdentifier; private final byte deviceId; private final HttpHeaders headers; private final SocketAddress remoteServerAddress; private final WebSocketCloseListener webSocketCloseListener; @Nullable private final Supplier proxyMessageSupplier; + // If provided, will be sent with the payload in the noise handshake + private final byte[] fastOpenRequest; private final List pendingReads = new ArrayList<>(); @@ -54,26 +62,28 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { final URI websocketUri, final boolean authenticated, @Nullable final ECKeyPair ecKeyPair, - final ECPublicKey rootPublicKey, + final ECPublicKey serverPublicKey, @Nullable final UUID accountIdentifier, final byte deviceId, final HttpHeaders headers, final SocketAddress remoteServerAddress, final WebSocketCloseListener webSocketCloseListener, - @Nullable Supplier proxyMessageSupplier) { + @Nullable Supplier proxyMessageSupplier, + @Nullable byte[] fastOpenRequest) { this.useTls = useTls; this.trustedServerCertificate = trustedServerCertificate; this.websocketUri = websocketUri; this.authenticated = authenticated; this.ecKeyPair = ecKeyPair; - this.rootPublicKey = rootPublicKey; + this.serverPublicKey = serverPublicKey; this.accountIdentifier = accountIdentifier; this.deviceId = deviceId; this.headers = headers; this.remoteServerAddress = remoteServerAddress; this.webSocketCloseListener = webSocketCloseListener; this.proxyMessageSupplier = proxyMessageSupplier; + this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest; } @Override @@ -104,6 +114,10 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc())); } + final NoiseClientHandshakeHelper helper = authenticated + ? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair) + : NoiseClientHandshakeHelper.NK(serverPublicKey); + channel.pipeline() .addLast(new HttpClientCodec()) .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) @@ -118,22 +132,24 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { Noise.MAX_PACKET_LEN, 10_000)) .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener)) - .addLast(authenticated - ? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId) - : new NoiseNXClientHandshakeHandler(rootPublicKey)) + // Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done + .addLast(new NoiseClientHandshakeHandler(helper, initialPayload())) .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event) throws Exception { - if (event instanceof NoiseHandshakeCompleteEvent) { + if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) { remoteContext.pipeline() .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); - localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); + // If there was a payload response on the handshake, write it back to our gRPC client + handshakeCompleteEvent.fastResponse().ifPresent(plaintext -> + localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext))); + + // Forward any messages we got from our gRPC client, now will be proxied to the remote context pendingReads.forEach(localContext::fireChannelRead); pendingReads.clear(); - localContext.pipeline().remove(EstablishRemoteConnectionHandler.this); } @@ -165,4 +181,18 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { pendingReads.forEach(ReferenceCountUtil::release); pendingReads.clear(); } + + private byte[] initialPayload() { + if (!authenticated) { + return fastOpenRequest; + } + + final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length); + bb.putLong(accountIdentifier.getMostSignificantBits()); + bb.putLong(accountIdentifier.getLeastSignificantBits()); + bb.put(deviceId); + bb.put(fastOpenRequest); + bb.flip(); + return bb.array(); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java new file mode 100644 index 000000000..7f7b68df6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/FastOpenRequestBufferedEvent.java @@ -0,0 +1,9 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; + +record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java new file mode 100644 index 000000000..24e2e84c4 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/Http2Buffering.java @@ -0,0 +1,184 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.ReferenceCountUtil; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.List; +import java.util.stream.Stream; + +/** + * The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast + * open" optimization where the client can send a request along with the noise handshake. There's no direct way to + * extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an + * inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request. + *

+ * Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a + * {@link FastOpenRequestBufferedEvent} + */ +class Http2Buffering { + + /** + * Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request + */ + static ChannelInboundHandler handler() { + return new Http2PrefaceHandler(); + } + + private Http2Buffering() { + } + + private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter { + + // https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5 + private static final byte[] HTTP2_PREFACE = + HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a"); + private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length); + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + if (message instanceof ByteBuf bb) { + bb.readBytes(read); + if (read.readableBytes() < HTTP2_PREFACE.length) { + // Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more. + return; + } + if (!Arrays.equals(read.array(), HTTP2_PREFACE)) { + throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface"); + } + context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder()); + context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder()); + context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler()); + context.fireChannelRead(bb); + } else { + throw new IllegalStateException("Unexpected message: " + message); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + ReferenceCountUtil.release(read); + } + } + + + private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) { + + private static final byte FLAG_END_STREAM = 0x01; + + enum FrameType { + SETTINGS, + HEADERS, + DATA, + WINDOW_UPDATE, + OTHER; + + static FrameType fromSerializedType(final byte type) { + return switch (type) { + case 0x00 -> Http2Frame.FrameType.DATA; + case 0x01 -> Http2Frame.FrameType.HEADERS; + case 0x04 -> Http2Frame.FrameType.SETTINGS; + case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE; + default -> Http2Frame.FrameType.OTHER; + }; + } + } + } + + /** + * Emit ByteBuf of entire HTTP/2 frame + */ + private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder { + + public Http2LengthFieldFrameDecoder() { + // Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload + super(16 * 1024 * 1024, 0, 3, 6, 0); + } + } + + /** + * Parse the serialized Http/2 frames into {@link Http2Frame} objects + */ + private static class Http2FrameDecoder extends ByteToMessageDecoder { + + @Override + protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List out) throws Exception { + // https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1 + final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3)); + final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4)); + out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream)); + } + + boolean endStream(Http2Frame.FrameType frameType, byte flags) { + // A gRPC request are packed into HTTP/2 frames like: + // HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1) + // + // Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set + // we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about + // is having at least one complete request. In total, we can expect something like: + // HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1) + // + // The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack + // bit set, but HTTP/2 allows the client to send frames before getting the ACK. + if (frameType == Http2Frame.FrameType.DATA) { + return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM; + } + + // In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete + // (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a + // DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we + // might as well pack that in. + return frameType == Http2Frame.FrameType.WINDOW_UPDATE; + } + } + + /** + * Collect HTTP/2 frames until we get an entire "request" to send + */ + private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter { + + final List pendingFrames = new ArrayList<>(); + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + if (message instanceof Http2Frame http2Frame) { + if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) { + throw new IllegalStateException( + "HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type); + } + pendingFrames.add(http2Frame); + if (http2Frame.endStream) { + // We have a whole "request", emit the first request event and remove the http2 buffering handlers + final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat( + Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)), + pendingFrames.stream().map(Http2Frame::bytes)) + .toArray(ByteBuf[]::new)); + pendingFrames.clear(); + context.pipeline().remove(Http2LengthFieldFrameDecoder.class); + context.pipeline().remove(Http2FrameDecoder.class); + context.pipeline().remove(this); + context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request)); + } + } else { + throw new IllegalStateException("Unexpected message: " + message); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes())); + pendingFrames.clear(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java new file mode 100644 index 000000000..51518a274 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java @@ -0,0 +1,108 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.util.Optional; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { + + @Override + protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) { + + return new NoiseAnonymousHandler(serverKeyPair); + } + + @Override + protected CipherStatePair doHandshake() throws Exception { + return doHandshake(new byte[0]); + } + + private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception { + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + final HandshakeState clientHandshakeState = + new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR); + + clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0); + clientHandshakeState.start(); + + // Send initiator handshake message + + // 32 byte key, request payload, 16 byte AEAD tag + final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16; + final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength]; + assertEquals( + initiateHandshakeMessageLength, + clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length)); + + final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame( + Unpooled.wrappedBuffer(initiateHandshakeMessage)); + + assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess()); + assertEquals(0, initiateHandshakeFrame.refCnt()); + + embeddedChannel.runPendingTasks(); + + // Read responder handshake message + assertFalse(embeddedChannel.outboundMessages().isEmpty()); + final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame) + embeddedChannel.outboundMessages().poll(); + @SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes = + new byte[responderHandshakeFrame.content().readableBytes()]; + responderHandshakeFrame.content().readBytes(responderHandshakeBytes); + + // ephemeral key, empty encrypted payload AEAD tag + final byte[] handshakeResponsePayload = new byte[32 + 16]; + + assertEquals(0, + clientHandshakeState.readMessage( + responderHandshakeBytes, 0, responderHandshakeBytes.length, + handshakeResponsePayload, 0)); + + final byte[] serverPublicKey = new byte[32]; + clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); + assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes()); + + return clientHandshakeState.split(); + } + + @Test + void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException { + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class)); + + final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes())); + final byte[] response = readNextPlaintext(cipherStatePair); + assertArrayEquals(response, "pong".getBytes()); + + assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException { + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class)); + + final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0])); + assertNull(readNextPlaintext(cipherStatePair)); + + assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java new file mode 100644 index 000000000..1aeb0f516 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java @@ -0,0 +1,301 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.internal.EmptyArrays; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ThreadLocalRandom; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { + + private ClientPublicKeysManager clientPublicKeysManager; + private final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + @Override + @BeforeEach + void setUp() { + clientPublicKeysManager = mock(ClientPublicKeysManager.class); + + super.setUp(); + } + + @Override + protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) { + return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair); + } + + @Override + protected CipherStatePair doHandshake() throws Throwable { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + return doHandshake(identityPayload(accountIdentifier, deviceId)); + } + + @Test + void handleCompleteHandshakeNoInitialRequest() throws Throwable { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId)))); + + assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleCompleteHandshakeWithInitialRequest() throws Throwable { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final ByteBuffer bb = ByteBuffer.allocate(17 + 4); + bb.put(identityPayload(accountIdentifier, deviceId)); + bb.put("ping".getBytes()); + + final byte[] response = readNextPlaintext(doHandshake(bb.array())); + assertEquals(response.length, 4); + assertEquals(new String(response), "pong"); + + assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleCompleteHandshakeMissingIdentityInformation() { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES)); + + verifyNoInteractions(clientPublicKeysManager); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakeMalformedIdentityInformation() { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + // no deviceId byte + byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID()); + assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload)); + + verifyNoInteractions(clientPublicKeysManager); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakeUnrecognizedDevice() { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakePublicKeyMismatch() { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); + + assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + } + + @Test + void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException { + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + + final HandshakeState clientHandshakeState = clientHandshakeState(); + + final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); + + final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer( + initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)))); + assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess()); + + // While waiting for the public key, send another message + final ChannelFuture f = embeddedChannel.writeOneInbound( + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await(); + assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow()); + + findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); + embeddedChannel.runPendingTasks(); + + // shouldn't return any response or error, we've already processed an error + embeddedChannel.checkException(); + assertNull(embeddedChannel.outboundMessages().poll()); + } + + private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException { + final HandshakeState clientHandshakeState = + new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); + + clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0); + clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0); + clientHandshakeState.start(); + return clientHandshakeState; + } + + private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload) + throws ShortBufferException { + // Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag + final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16]; + int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length); + assertEquals(written, initiatorMessageBytes.length); + return initiatorMessageBytes; + } + + private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message) + throws ShortBufferException, BadPaddingException { + + // 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload + final int expectedResponsePayloadLength = message.length - 32 - 16; + final byte[] responsePayload = new byte[expectedResponsePayloadLength]; + final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0); + assertEquals(expectedResponsePayloadLength, responsePayloadLength); + return responsePayload; + } + + private CipherStatePair doHandshake(final byte[] payload) throws Throwable { + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + final HandshakeState clientHandshakeState = clientHandshakeState(); + final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload); + + final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame( + Unpooled.wrappedBuffer(initiatorMessage)); + final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await(); + assertEquals(0, initiatorMessageFrame.refCnt()); + if (!await.isSuccess()) { + throw await.cause(); + } + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + // rethrow if running the task caused an error + embeddedChannel.checkException(); + + assertFalse(embeddedChannel.outboundMessages().isEmpty()); + + final BinaryWebSocketFrame serverStaticKeyMessageFrame = + (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = + new byte[serverStaticKeyMessageFrame.content().readableBytes()]; + serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); + + assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0); + + final byte[] serverPublicKey = new byte[32]; + clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); + assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes()); + + return clientHandshakeState.split(); + } + + + private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) { + final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17); + clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits()); + clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits()); + clientIdentityPayloadBuffer.put(deviceId); + clientIdentityPayloadBuffer.flip(); + return clientIdentityPayloadBuffer.array(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java new file mode 100644 index 000000000..ab94f1d2f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeCompleteEvent.java @@ -0,0 +1,15 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import java.util.Optional; + +/** + * A netty user event that indicates that the noise handshake finished successfully. + * + * @param fastResponse A response if the client included a request to send in the initiate handshake message payload and + * the server included a payload in the handshake response. + */ +public record NoiseClientHandshakeCompleteEvent(Optional fastResponse) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java new file mode 100644 index 000000000..c8df5b2be --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHandler.java @@ -0,0 +1,55 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import java.util.Optional; + +class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter { + + private final NoiseClientHandshakeHelper handshakeHelper; + private final byte[] payload; + + NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) { + this.handshakeHelper = handshakeHelper; + this.payload = payload; + } + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { + if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { + if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + byte[] handshakeMessage = handshakeHelper.write(payload); + context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage))) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } + } + super.userEventTriggered(context, event); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) + throws NoiseHandshakeException { + if (message instanceof BinaryWebSocketFrame frame) { + try { + final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content())); + final Optional fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload); + context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split())); + context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse)); + } finally { + frame.release(); + } + } else { + context.fireChannelRead(message); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + handshakeHelper.destroy(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java new file mode 100644 index 000000000..cb3998628 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientHandshakeHelper.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.HandshakeState; +import java.security.NoSuchAlgorithmException; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +public class NoiseClientHandshakeHelper { + + private final HandshakePattern handshakePattern; + private final HandshakeState handshakeState; + + private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) { + this.handshakePattern = handshakePattern; + this.handshakeState = handshakeState; + } + + static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) { + try { + final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); + state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0); + state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0); + state.start(); + return new NoiseClientHandshakeHelper(HandshakePattern.IK, state); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException(e); + } + } + + static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) { + try { + final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR); + state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0); + state.start(); + return new NoiseClientHandshakeHelper(HandshakePattern.NK, state); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException(e); + } + } + + byte[] write(final byte[] requestPayload) throws ShortBufferException { + final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16]; + handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length); + return initiateHandshakeMessage; + } + + private int initiateHandshakeKeysLength() { + return switch (handshakePattern) { + // 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag + case IK -> 32 + 32 + 16; + // 32-byte ephemeral key + case NK -> 32; + }; + } + + byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException { + // Don't process additional messages if the handshake failed and we're just waiting to close + if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) { + throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction()); + } + final int payloadLength = responderHandshakeMessage.length - 16 - 32; + final byte[] responsePayload = new byte[payloadLength]; + final int payloadBytesRead; + try { + payloadBytesRead = handshakeState + .readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0); + if (payloadBytesRead != responsePayload.length) { + throw new IllegalStateException( + "Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead); + } + return responsePayload; + } catch (ShortBufferException e) { + throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage()); + } catch (BadPaddingException e) { + throw new NoiseHandshakeException(e.getMessage()); + } + } + + CipherStatePair split() { + return this.handshakeState.split(); + } + + void destroy() { + this.handshakeState.destroy(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java similarity index 79% rename from service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandler.java rename to service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java index dff884046..57229af83 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseClientTransportHandler.java @@ -11,30 +11,26 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A Noise transport handler manages a bidirectional Noise session after a handshake has completed. */ -class NoiseTransportHandler extends ChannelDuplexHandler { +class NoiseClientTransportHandler extends ChannelDuplexHandler { private final CipherStatePair cipherStatePair; - private static final Logger log = LoggerFactory.getLogger(NoiseTransportHandler.class); + private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class); - NoiseTransportHandler(CipherStatePair cipherStatePair) { + NoiseClientTransportHandler(CipherStatePair cipherStatePair) { this.cipherStatePair = cipherStatePair; } @Override - public void channelRead(final ChannelHandlerContext context, final Object message) - throws ShortBufferException, BadPaddingException { - - if (message instanceof BinaryWebSocketFrame frame) { - try { + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + try { + if (message instanceof BinaryWebSocketFrame frame) { final CipherState cipherState = cipherStatePair.getReceiver(); // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. @@ -45,22 +41,22 @@ class NoiseTransportHandler extends ChannelDuplexHandler { final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length); context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength)); - } finally { - frame.release(); + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); } - } else { - // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an - // error + } finally { ReferenceCountUtil.release(message); - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); } } + @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) + throws Exception { if (message instanceof ByteBuf plaintext) { try { - // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames final CipherState cipherState = cipherStatePair.getSender(); final int plaintextLength = plaintext.readableBytes(); @@ -75,7 +71,7 @@ class NoiseTransportHandler extends ChannelDuplexHandler { context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise); } finally { - plaintext.release(); + ReferenceCountUtil.release(plaintext); } } else { if (!(message instanceof WebSocketFrame)) { @@ -83,7 +79,6 @@ class NoiseTransportHandler extends ChannelDuplexHandler { // get issued in response to exceptions) log.warn("Unexpected object in pipeline: {}", message); } - context.write(message, promise); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java new file mode 100644 index 000000000..e818a93e3 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBuf; +import java.nio.charset.StandardCharsets; +import javax.crypto.ShortBufferException; +import io.netty.buffer.ByteBufUtil; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + + +public class NoiseHandshakeHelperTest { + + @ParameterizedTest + @EnumSource(HandshakePattern.class) + void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { + doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8)); + } + + @ParameterizedTest + @EnumSource(HandshakePattern.class) + void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { + doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]); + } + + @ParameterizedTest + @EnumSource(HandshakePattern.class) + void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { + doHandshake(pattern, new byte[0], new byte[0]); + } + + void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException { + final ECKeyPair serverKeyPair = Curve.generateKeyPair(); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair); + NoiseClientHandshakeHelper clientHelper = switch (pattern) { + case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair); + case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey()); + }; + + final byte[] initiate = clientHelper.write(requestPayload); + final ByteBuf actualRequestPayload = serverHelper.read(initiate); + assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload); + + assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE); + + final byte[] respond = serverHelper.write(responsePayload); + byte[] actualResponsePayload = clientHelper.read(respond); + assertThat(actualResponsePayload).isEqualTo(responsePayload); + + assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT); + assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split()); + assertThatNoException().isThrownBy(() -> clientHelper.split()); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java deleted file mode 100644 index 8e39de99d..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java +++ /dev/null @@ -1,47 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import java.util.Optional; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler { - - private boolean receivedServerStaticKeyMessage = false; - - NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) { - super(rootPublicKey); - } - - @Override - protected String getNoiseProtocolName() { - return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME; - } - - @Override - protected void startHandshake() { - getHandshakeState().start(); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof BinaryWebSocketFrame frame) { - try { - // Don't process additional messages if we're just waiting to close because the handshake failed - if (receivedServerStaticKeyMessage) { - return; - } - - receivedServerStaticKeyMessage = true; - handleServerStaticKeyMessage(context, frame); - - context.pipeline().replace(this, null, new NoiseTransportHandler(getHandshakeState().split())); - context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); - } finally { - frame.release(); - } - } else { - context.fireChannelRead(message); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java deleted file mode 100644 index 53760e265..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java +++ /dev/null @@ -1,84 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.Unpooled; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import java.security.NoSuchAlgorithmException; -import java.util.Optional; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.junit.jupiter.api.Test; -import org.signal.libsignal.protocol.ecc.ECKeyPair; - -class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest { - - @Override - protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair, - final byte[] serverPublicKeySignature) { - - return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature); - } - - @Test - void handleCompleteHandshake() - throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class)); - - final HandshakeState clientHandshakeState = - new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); - - clientHandshakeState.start(); - - { - final byte[] ephemeralKeyMessageBytes = new byte[32]; - clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0); - - final BinaryWebSocketFrame ephemeralKeyMessageFrame = - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes)); - - assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess()); - assertEquals(0, ephemeralKeyMessageFrame.refCnt()); - } - - { - assertEquals(1, embeddedChannel.outboundMessages().size()); - - final BinaryWebSocketFrame serverStaticKeyMessageFrame = - (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); - - @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = - new byte[serverStaticKeyMessageFrame.content().readableBytes()]; - - serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); - - final byte[] serverPublicKeySignature = new byte[64]; - - final int payloadLength = - clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0); - - assertEquals(serverPublicKeySignature.length, payloadLength); - - final byte[] serverPublicKey = new byte[32]; - clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); - - assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature)); - } - - assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); - - assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class), - "Handshake handler should remove self from pipeline after successful handshake"); - - assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Handshake handler should insert a Noise stream handler after successful handshake"); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandlerTest.java deleted file mode 100644 index 08d290702..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseTransportHandlerTest.java +++ /dev/null @@ -1,135 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.util.internal.EmptyArrays; -import io.netty.util.internal.ThreadLocalRandom; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import javax.crypto.AEADBadTagException; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import java.nio.charset.StandardCharsets; -import java.security.NoSuchAlgorithmException; - -import static org.junit.jupiter.api.Assertions.*; - -class NoiseTransportHandlerTest extends AbstractLeakDetectionTest { - - private CipherStatePair clientCipherStatePair; - private EmbeddedChannel embeddedChannel; - - // We use an NN handshake for this test just because it's a little shorter and easier to set up - private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b"; - - @BeforeEach - void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException { - final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); - final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER); - - clientHandshakeState.start(); - serverHandshakeState.start(); - - final byte[] clientEphemeralKeyMessage = new byte[32]; - assertEquals(clientEphemeralKeyMessage.length, - clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0)); - - serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); - - // 32 bytes of key material plus a 16-byte MAC - final byte[] serverEphemeralKeyMessage = new byte[48]; - assertEquals(serverEphemeralKeyMessage.length, - serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0)); - - clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); - - clientCipherStatePair = clientHandshakeState.split(); - embeddedChannel = new EmbeddedChannel(new NoiseTransportHandler(serverHandshakeState.split())); - - clientHandshakeState.destroy(); - serverHandshakeState.destroy(); - } - - @Test - void channelRead() throws ShortBufferException, InterruptedException { - final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); - final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()]; - clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length); - - final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext)); - assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess()); - assertEquals(0, ciphertextFrame.refCnt()); - - final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll(); - assertNotNull(decryptedPlaintextBuffer); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - - final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer); - decryptedPlaintextBuffer.release(); - - assertArrayEquals(plaintext, decryptedPlaintext); - } - - @Test - void channelReadBadCiphertext() throws InterruptedException { - final byte[] bogusCiphertext = new byte[32]; - ThreadLocalRandom.current().nextBytes(bogusCiphertext); - - final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext)); - final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await(); - - assertEquals(0, ciphertextFrame.refCnt()); - assertFalse(readCiphertextFuture.isSuccess()); - assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause()); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } - - @Test - void channelReadUnexpectedMessageType() throws InterruptedException { - final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await(); - - assertFalse(readFuture.isSuccess()); - assertInstanceOf(IllegalArgumentException.class, readFuture.cause()); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } - - @Test - void write() throws InterruptedException, ShortBufferException, BadPaddingException { - final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); - final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext); - - final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); - assertTrue(writePlaintextFuture.await().isSuccess()); - assertEquals(0, plaintextBuffer.refCnt()); - - final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); - assertNotNull(ciphertextFrame); - assertTrue(embeddedChannel.outboundMessages().isEmpty()); - - final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); - ciphertextFrame.release(); - - final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()]; - clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length); - - assertArrayEquals(plaintext, decryptedPlaintext); - } - - @Test - void writeUnexpectedMessageType() throws InterruptedException { - final Object unexpectedMessaged = new Object(); - - final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged); - assertTrue(writeFuture.await().isSuccess()); - - assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll()); - assertTrue(embeddedChannel.outboundMessages().isEmpty()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java index 44c2974ef..74e760e6c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelClient.java @@ -1,22 +1,26 @@ package org.whispersystems.textsecuregcm.grpc.net; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufUtil; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; import java.net.SocketAddress; import java.net.URI; import java.security.cert.X509Certificate; import java.util.UUID; +import java.util.function.Function; import java.util.function.Supplier; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.http.HttpHeaders; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; -import javax.annotation.Nullable; class NoiseWebSocketTunnelClient implements AutoCloseable { @@ -26,19 +30,86 @@ class NoiseWebSocketTunnelClient implements AutoCloseable { static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated"); static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous"); - public NoiseWebSocketTunnelClient(final SocketAddress remoteServerAddress, - final URI websocketUri, - final boolean authenticated, - final ECKeyPair ecKeyPair, - final ECPublicKey rootPublicKey, - @Nullable final UUID accountIdentifier, - final byte deviceId, - final HttpHeaders headers, - final boolean useTls, - @Nullable final X509Certificate trustedServerCertificate, - @Nullable final Supplier proxyMessageSupplier, - final NioEventLoopGroup eventLoopGroup, - final WebSocketCloseListener webSocketCloseListener) { + static class Builder { + + final SocketAddress remoteServerAddress; + NioEventLoopGroup eventLoopGroup; + ECPublicKey serverPublicKey; + + URI websocketUri = ANONYMOUS_WEBSOCKET_URI; + HttpHeaders headers = new DefaultHttpHeaders(); + WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER; + + boolean authenticated = false; + ECKeyPair ecKeyPair = null; + UUID accountIdentifier = null; + byte deviceId = 0x00; + boolean useTls; + X509Certificate trustedServerCertificate = null; + Supplier proxyMessageSupplier = null; + + Builder( + final SocketAddress remoteServerAddress, + final NioEventLoopGroup eventLoopGroup, + final ECPublicKey serverPublicKey) { + this.remoteServerAddress = remoteServerAddress; + this.eventLoopGroup = eventLoopGroup; + this.serverPublicKey = serverPublicKey; + } + + Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) { + this.authenticated = true; + this.accountIdentifier = accountIdentifier; + this.deviceId = deviceId; + this.ecKeyPair = ecKeyPair; + this.websocketUri = AUTHENTICATED_WEBSOCKET_URI; + return this; + } + + Builder setWebsocketUri(final URI websocketUri) { + this.websocketUri = websocketUri; + return this; + } + + Builder setUseTls(X509Certificate trustedServerCertificate) { + this.useTls = true; + this.trustedServerCertificate = trustedServerCertificate; + return this; + } + + Builder setProxyMessageSupplier(Supplier proxyMessageSupplier) { + this.proxyMessageSupplier = proxyMessageSupplier; + return this; + } + + Builder setHeaders(final HttpHeaders headers) { + this.headers = headers; + return this; + } + + Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) { + this.webSocketCloseListener = webSocketCloseListener; + return this; + } + + Builder setServerPublicKey(ECPublicKey serverPublicKey) { + this.serverPublicKey = serverPublicKey; + return this; + } + + NoiseWebSocketTunnelClient build() { + final NoiseWebSocketTunnelClient client = + new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler( + useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey, + accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier, + fastOpenRequest)); + client.start(); + return client; + } + } + + private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup, + Function handler) { this.serverBootstrap = new ServerBootstrap() .localAddress(new LocalAddress("websocket-noise-tunnel-client")) @@ -47,28 +118,37 @@ class NoiseWebSocketTunnelClient implements AutoCloseable { .childHandler(new ChannelInitializer() { @Override protected void initChannel(final LocalChannel localChannel) { - localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(useTls, - trustedServerCertificate, - websocketUri, - authenticated, - ecKeyPair, - rootPublicKey, - accountIdentifier, - deviceId, - headers, - remoteServerAddress, - webSocketCloseListener, - proxyMessageSupplier)); + localChannel.pipeline() + // We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the + // stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put + // in the handshake. + .addLast(Http2Buffering.handler()) + // Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At + // that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually + // connect to the remote service + .addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) { + byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest()); + requestBufferedEvent.fastOpenRequest().release(); + ctx.pipeline().addLast(handler.apply(fastOpenRequest)); + } + super.userEventTriggered(ctx, evt); + } + }) + .addLast(new ClientErrorHandler()); } }); } + LocalAddress getLocalAddress() { return (LocalAddress) serverChannel.localAddress(); } - NoiseWebSocketTunnelClient start() throws InterruptedException { - serverChannel = serverBootstrap.bind().await().channel(); + private NoiseWebSocketTunnelClient start() { + serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel(); return this; } @@ -76,4 +156,5 @@ class NoiseWebSocketTunnelClient implements AutoCloseable { public void close() throws InterruptedException { serverChannel.close().await(); } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java index 844004c1e..655d17357 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java @@ -50,6 +50,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManagerFactory; import org.apache.commons.lang3.RandomStringUtils; @@ -67,7 +68,6 @@ import org.signal.chat.rpc.GetRequestAttributesResponse; import org.signal.chat.rpc.RequestAttributesGrpc; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; @@ -89,7 +89,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes private ClientConnectionManager clientConnectionManager; private ClientPublicKeysManager clientPublicKeysManager; - private ECKeyPair rootKeyPair; + private ECKeyPair serverKeyPair; private ECKeyPair clientKeyPair; private ManagedLocalGrpcServer authenticatedGrpcServer; @@ -153,9 +153,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY))); } - rootKeyPair = Curve.generateKeyPair(); clientKeyPair = Curve.generateKeyPair(); - final ECKeyPair serverKeyPair = Curve.generateKeyPair(); + serverKeyPair = Curve.generateKeyPair(); clientConnectionManager = new ClientConnectionManager(); @@ -192,14 +191,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes anonymousGrpcServer.start(); tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, - new X509Certificate[] { serverTlsCertificate }, + new X509Certificate[]{serverTlsCertificate}, serverTlsPrivateKey, nioEventLoopGroup, delegatedTaskExecutor, clientConnectionManager, clientPublicKeysManager, serverKeyPair, - rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()), authenticatedGrpcServerAddress, anonymousGrpcServerAddress, RECOGNIZED_PROXY_SECRET); @@ -214,7 +212,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes clientConnectionManager, clientPublicKeysManager, serverKeyPair, - rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()), authenticatedGrpcServerAddress, anonymousGrpcServerAddress, RECOGNIZED_PROXY_SECRET); @@ -241,9 +238,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } @ParameterizedTest - @ValueSource(booleans = { true, false }) + @ValueSource(booleans = {true, false}) void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) { + try (final NoiseWebSocketTunnelClient client = authenticated() + .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -259,24 +258,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } @ParameterizedTest - @ValueSource(booleans = { true, false }) + @ValueSource(booleans = {true, false}) void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient( - tlsNoiseWebSocketTunnelServer.getLocalAddress(), - NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI, - true, - clientKeyPair, - rootKeyPair.getPublicKey(), - ACCOUNT_IDENTIFIER, - DEVICE_ID, - new DefaultHttpHeaders(), - true, - serverTlsCertificate, - includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null, - nioEventLoopGroup, - WebSocketCloseListener.NOOP_LISTENER) - .start()) { - + try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient + .Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) + .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID) + .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -295,10 +283,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes void connectAuthenticatedBadServerKeySignature() throws InterruptedException { final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseWebSocketTunnelClient client = - buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) { + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebSocketCloseListener(webSocketCloseListener) + .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); @@ -312,7 +301,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); } @Test @@ -322,7 +312,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); - try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) { + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -335,7 +327,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); } @Test @@ -345,8 +338,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - try (final NoiseWebSocketTunnelClient client = - buildAndStartAuthenticatedClient(webSocketCloseListener)) { + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); @@ -360,29 +354,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); } @Test void connectAuthenticatedToAnonymousService() throws InterruptedException { final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient( - tlsNoiseWebSocketTunnelServer.getLocalAddress(), - URI.create("wss://localhost/anonymous"), - true, - clientKeyPair, - rootKeyPair.getPublicKey(), - ACCOUNT_IDENTIFIER, - DEVICE_ID, - new DefaultHttpHeaders(), - true, - serverTlsCertificate, - null, - nioEventLoopGroup, - webSocketCloseListener) - .start()) { - + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI) + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -395,12 +378,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); } @Test void connectAnonymous() throws InterruptedException { - try (final NoiseWebSocketTunnelClient client = buildAndStartAnonymousClient()) { + try (final NoiseWebSocketTunnelClient client = anonymous().build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -420,9 +404,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseWebSocketTunnelClient client = - buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) { - + try (final NoiseWebSocketTunnelClient client = anonymous() + .setWebSocketCloseListener(webSocketCloseListener) + .setServerPublicKey(Curve.generateKeyPair().getPublicKey()) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -435,29 +420,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); } @Test void connectAnonymousToAuthenticatedService() throws InterruptedException { final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); - try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient( - tlsNoiseWebSocketTunnelServer.getLocalAddress(), - URI.create("wss://localhost/authenticated"), - false, - null, - rootKeyPair.getPublicKey(), - null, - (byte) 0, - new DefaultHttpHeaders(), - true, - serverTlsCertificate, - null, - nioEventLoopGroup, - webSocketCloseListener) - .start()) { - + try (final NoiseWebSocketTunnelClient client = anonymous() + .setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI) + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); try { @@ -470,7 +444,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + verify(webSocketCloseListener).handleWebSocketClosedByServer( + ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); } private ManagedChannel buildManagedChannel(final LocalAddress localAddress) { @@ -506,8 +481,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes assertEquals(405, httpClient.send(HttpRequest.newBuilder() .uri(authenticatedUri) .PUT(HttpRequest.BodyPublishers.ofString("test")) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), "Non-GET requests should not be allowed"); assertEquals(426, httpClient.send(HttpRequest.newBuilder() @@ -538,8 +513,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes .add("Accept-Language", acceptLanguage) .add("User-Agent", userAgent); - try (final NoiseWebSocketTunnelClient client = - buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) { + try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); @@ -582,7 +556,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } }; - try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) { + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); @@ -606,63 +582,24 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } - private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException { - return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER); + private NoiseWebSocketTunnelClient.Builder anonymous() { + return new NoiseWebSocketTunnelClient + .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) + .setUseTls(serverTlsCertificate); + } - private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener) - throws InterruptedException { - - return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false); + private NoiseWebSocketTunnelClient.Builder authenticated() { + return new NoiseWebSocketTunnelClient + .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) + .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID) + .setUseTls(serverTlsCertificate); } - private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener, - final ECPublicKey rootPublicKey, - final HttpHeaders headers, - final boolean includeProxyMessage) throws InterruptedException { - - return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(), - NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI, - true, - clientKeyPair, - rootPublicKey, - ACCOUNT_IDENTIFIER, - DEVICE_ID, - headers, - true, - serverTlsCertificate, - includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null, - nioEventLoopGroup, - webSocketCloseListener) - .start(); - } - - private NoiseWebSocketTunnelClient buildAndStartAnonymousClient() throws InterruptedException { - return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders()); - } - - private NoiseWebSocketTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener, - final ECPublicKey rootPublicKey, - final HttpHeaders headers) throws InterruptedException { - - return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(), - NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI, - false, - null, - rootPublicKey, - null, - (byte) 0, - headers, - true, - serverTlsCertificate, - null, - nioEventLoopGroup, - webSocketCloseListener) - .start(); - } - - private static HAProxyMessage buildProxyMessage() { - return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, - "10.0.0.1", "10.0.0.2", 12345, 443); + private static Supplier proxyMessageSupplier(boolean includeProxyMesage) { + return includeProxyMesage + ? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, + "10.0.0.1", "10.0.0.2", 12345, 443) + : null; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java deleted file mode 100644 index 7b8dab92e..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java +++ /dev/null @@ -1,89 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import java.nio.ByteBuffer; -import java.util.Optional; -import java.util.UUID; -import javax.crypto.ShortBufferException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; - -class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler { - - private final ECKeyPair ecKeyPair; - - private final UUID accountIdentifier; - private final byte deviceId; - - private boolean receivedServerStaticKeyMessage = false; - - NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair, - final ECPublicKey rootPublicKey, - final UUID accountIdentifier, - final byte deviceId) { - - super(rootPublicKey); - - this.ecKeyPair = ecKeyPair; - - this.accountIdentifier = accountIdentifier; - this.deviceId = deviceId; - } - - @Override - protected String getNoiseProtocolName() { - return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME; - } - - @Override - protected void startHandshake() { - final HandshakeState handshakeState = getHandshakeState(); - - // Noise-java derives the public key from the private key, so we just need to set the private key - handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0); - handshakeState.start(); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) - throws NoiseHandshakeException, ShortBufferException { - if (message instanceof BinaryWebSocketFrame frame) { - try { - // Don't process additional messages if the handshake failed and we're just waiting to close - if (receivedServerStaticKeyMessage) { - return; - } - - receivedServerStaticKeyMessage = true; - handleServerStaticKeyMessage(context, frame); - - final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17); - clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits()); - clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits()); - clientIdentityBuffer.put(deviceId); - clientIdentityBuffer.flip(); - - final HandshakeState handshakeState = getHandshakeState(); - - // We're sending two 32-byte keys plus the client identity payload - final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()]; - handshakeState.writeMessage( - staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining()); - - context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage))) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - - context.pipeline().replace(this, null, new NoiseTransportHandler(handshakeState.split())); - context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); - } finally { - frame.release(); - } - } else { - context.fireChannelRead(message); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java deleted file mode 100644 index 6718b3532..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java +++ /dev/null @@ -1,453 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.southernstorm.noise.protocol.CipherState; -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import java.nio.ByteBuffer; -import java.security.NoSuchAlgorithmException; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ThreadLocalRandom; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import io.netty.util.internal.EmptyArrays; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.storage.Device; - -class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest { - - private ClientPublicKeysManager clientPublicKeysManager; - - @Override - @BeforeEach - void setUp() { - clientPublicKeysManager = mock(ClientPublicKeysManager.class); - - super.setUp(); - } - - @Override - protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair, - final byte[] serverPublicKeySignature) { - - return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature); - } - - @Test - void handleCompleteHandshake() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), - getNoiseHandshakeCompleteEvent()); - - assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should remove self from pipeline after successful handshake"); - - assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Handshake handler should insert a Noise stream handler after successful handshake"); - } - - @Test - void handleCompleteHandshakeMissingIdentityInformation() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - - { - final byte[] clientStaticKeyMessageBytes = new byte[64]; - final int messageLength = - clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0); - - assertEquals(clientStaticKeyMessageBytes.length, messageLength); - - final BinaryWebSocketFrame clientStaticKeyMessageFrame = - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); - - final ChannelFuture writeClientStaticKeyMessageFuture = - getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await(); - - assertFalse(writeClientStaticKeyMessageFuture.isSuccess()); - assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause()); - assertEquals(0, clientStaticKeyMessageFrame.refCnt()); - } - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakeMalformedIdentityInformation() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - - { - final byte[] clientStaticKeyMessageBytes = new byte[96]; - final int messageLength = - clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32); - - assertEquals(clientStaticKeyMessageBytes.length, messageLength); - - final BinaryWebSocketFrame clientStaticKeyMessageFrame = - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); - - final ChannelFuture writeClientStaticKeyMessageFuture = - getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await(); - - assertFalse(writeClientStaticKeyMessageFuture.isSuccess()); - assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause()); - assertEquals(0, clientStaticKeyMessageFrame.refCnt()); - } - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakeUnrecognizedDevice() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakePublicKeyMismatch() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakeBufferedReads() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); - - final ByteBuf[] additionalMessages = new ByteBuf[4]; - final CipherState senderState = clientHandshakeState.split().getSender(); - - try { - for (int i = 0; i < additionalMessages.length; i++) { - final byte[] contentBytes = new byte[32]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - // Copy the "plaintext" portion of the content bytes for future assertions - additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16); - - // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD - // tag - senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16); - - assertTrue( - embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await() - .isSuccess()); - } - - findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), - getNoiseHandshakeCompleteEvent()); - - assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should remove self from pipeline after successful handshake"); - - assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Handshake handler should insert a Noise stream handler after successful handshake"); - - for (final ByteBuf additionalMessage : additionalMessages) { - assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(), - "Buffered message should pass through pipeline after successful handshake"); - } - } finally { - for (final ByteBuf additionalMessage : additionalMessages) { - additionalMessage.release(); - } - } - } - - @Test - void handleCompleteHandshakeFailureBufferedReads() - throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); - final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - - final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); - - final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); - sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); - - final ByteBuf[] additionalMessages = new ByteBuf[4]; - final CipherState senderState = clientHandshakeState.split().getSender(); - - try { - for (int i = 0; i < additionalMessages.length; i++) { - final byte[] contentBytes = new byte[32]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - // Copy the "plaintext" portion of the content bytes for future assertions - additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16); - - // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD - // tag - senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16); - - assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess()); - } - - findPublicKeyFuture.complete(Optional.empty()); - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - - assertTrue(embeddedChannel.inboundMessages().isEmpty(), - "Buffered messages should not pass through pipeline after failed handshake"); - } finally { - for (final ByteBuf additionalMessage : additionalMessages) { - additionalMessage.release(); - } - } - } - - private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair) - throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - final HandshakeState clientHandshakeState = - new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); - - clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0); - clientHandshakeState.start(); - - { - final byte[] ephemeralKeyMessageBytes = new byte[32]; - clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0); - - final BinaryWebSocketFrame ephemeralKeyMessageFrame = - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes)); - - assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess()); - assertEquals(0, ephemeralKeyMessageFrame.refCnt()); - } - - { - assertEquals(1, embeddedChannel.outboundMessages().size()); - - final BinaryWebSocketFrame serverStaticKeyMessageFrame = - (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); - - @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = - new byte[serverStaticKeyMessageFrame.content().readableBytes()]; - - serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); - - final byte[] serverPublicKeySignature = new byte[64]; - - final int payloadLength = - clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0); - - assertEquals(serverPublicKeySignature.length, payloadLength); - - final byte[] serverPublicKey = new byte[32]; - clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); - - assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature)); - } - - return clientHandshakeState; - } - - private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId) - throws ShortBufferException, InterruptedException { - - final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17); - clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits()); - clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits()); - clientIdentityPayloadBuffer.put(deviceId); - clientIdentityPayloadBuffer.flip(); - - final byte[] clientStaticKeyMessageBytes = new byte[81]; - final int messageLength = - handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining()); - - assertEquals(clientStaticKeyMessageBytes.length, messageLength); - - final BinaryWebSocketFrame clientStaticKeyMessageFrame = - new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); - - assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess()); - assertEquals(0, clientStaticKeyMessageFrame.refCnt()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java new file mode 100644 index 000000000..6a316d745 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/TypedNoiseChannelDuplexHandler.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline + * after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound + * messages to be bytes. + */ +abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler { + + private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class); + + /** + * Handle an inbound message. The frame will be automatically released after the method is finished running. + * + * @param context The current {@link ChannelHandlerContext} + * @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise + * packet + * @throws Exception + */ + abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception; + + /** + * Handle an outbound byte message. The message will be automatically released after the method is finished running. + * + * @param context The current {@link ChannelHandlerContext} + * @param bytes The bytes to write + * @throws Exception + */ + abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes, + final ChannelPromise promise) throws Exception; + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + try { + if (message instanceof BinaryWebSocketFrame frame) { + handleInbound(context, frame.content()); + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } finally { + ReferenceCountUtil.release(message); + } + } + + + @Override + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) + throws Exception { + if (message instanceof ByteBuf serverResponse) { + try { + handleOutbound(context, serverResponse, promise); + } finally { + ReferenceCountUtil.release(serverResponse); + } + } else { + if (!(message instanceof WebSocketFrame)) { + // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that + // get issued in response to exceptions) + log.warn("Unexpected object in pipeline: {}", message); + } + context.write(message, promise); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java index cc0abe611..cb8f85f30 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteHandlerTest.java @@ -79,7 +79,6 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { embeddedChannel = new MutableRemoteAddressEmbeddedChannel( new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), - new byte[64], RECOGNIZED_PROXY_SECRET), userEventRecordingHandler); @@ -88,7 +87,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { @ParameterizedTest @MethodSource - void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { + void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); @@ -102,8 +101,8 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { private static List handleWebSocketHandshakeComplete() { return List.of( - Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class), - Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class)); + Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class), + Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class)); } @Test diff --git a/service/src/test/resources/config/test-secrets-bundle.yml b/service/src/test/resources/config/test-secrets-bundle.yml index 7c72ee488..4fb7dfa5f 100644 --- a/service/src/test/resources/config/test-secrets-bundle.yml +++ b/service/src/test/resources/config/test-secrets-bundle.yml @@ -130,5 +130,7 @@ linkDevice.secret: AAAAAAAAAAA= tlsKeyStore.password: unset -noiseTunnel.noiseStaticPrivateKey: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= +# The below private key was generated exclusively for testing purposes. Do not use it in any other context. +# Corresponding public key: cYUAFtkWK/4x3AfW/yw7qgIo/mQUaRSWaPolGQkiL14= +noiseTunnel.noiseStaticPrivateKey: qK5FD9WmuhoLPsS/Z4swcZkwDn9OpeM5ZmcEVMpEQ24= noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA diff --git a/service/src/test/resources/config/test.yml b/service/src/test/resources/config/test.yml index 6ea660997..691ec2bf6 100644 --- a/service/src/test/resources/config/test.yml +++ b/service/src/test/resources/config/test.yml @@ -452,7 +452,6 @@ callingTurnManualTable: noiseTunnel: port: 8443 noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey - noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret externalRequestFilter: