From 0cc5431867a6c30405c25ea48d45b14068c372bb Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Tue, 13 May 2025 14:16:23 -0500 Subject: [PATCH] Update noise-gRPC protocol errors --- .../textsecuregcm/grpc/DeviceIdUtil.java | 6 +- .../net/ClientAuthenticationException.java | 9 - .../textsecuregcm/grpc/net/ErrorHandler.java | 7 +- .../EstablishLocalGrpcConnectionHandler.java | 8 +- .../grpc/net/NoiseAnonymousHandler.java | 34 --- .../grpc/net/NoiseAuthenticatedHandler.java | 96 --------- .../textsecuregcm/grpc/net/NoiseHandler.java | 144 +++---------- .../grpc/net/NoiseHandshakeHandler.java | 197 ++++++++++++++++++ .../grpc/net/NoiseHandshakeInit.java | 33 +++ .../net/NoiseIdentityDeterminedEvent.java | 10 +- .../grpc/net/OutboundCloseErrorMessage.java | 6 +- .../net/noisedirect/NoiseDirectFrame.java | 11 +- .../noisedirect/NoiseDirectFrameCodec.java | 2 +- .../NoiseDirectHandshakeSelector.java | 54 ++--- .../NoiseDirectInboundCloseHandler.java | 36 ++++ .../NoiseDirectOutboundErrorHandler.java | 17 +- .../noisedirect/NoiseDirectTunnelServer.java | 15 +- .../ApplicationWebSocketCloseReason.java | 3 +- .../websocket/NoiseWebSocketTunnelServer.java | 9 +- .../WebSocketOutboundErrorHandler.java | 6 - .../WebsocketHandshakeCompleteHandler.java | 92 ++++---- service/src/main/proto/NoiseDirect.proto | 47 ++++- service/src/main/proto/NoiseTunnel.proto | 56 +++++ .../grpc/net/AbstractNoiseHandlerTest.java | 41 +++- ...tractNoiseTunnelServerIntegrationTest.java | 32 ++- .../grpc/net/NoiseAnonymousHandlerTest.java | 62 +++--- .../net/NoiseAuthenticatedHandlerTest.java | 162 ++++++++------ .../grpc/net/client/CloseFrameEvent.java | 16 +- .../EstablishRemoteConnectionHandler.java | 38 +--- .../NoiseClientHandshakeCompleteEvent.java | 4 +- .../client/NoiseClientHandshakeHandler.java | 14 +- .../client/NoiseClientTransportHandler.java | 8 +- .../grpc/net/client/NoiseTunnelClient.java | 120 ++++++++--- ...ocketNoiseTunnelServerIntegrationTest.java | 10 +- ...ocketNoiseTunnelServerIntegrationTest.java | 2 - ...WebsocketHandshakeCompleteHandlerTest.java | 51 ++--- 36 files changed, 856 insertions(+), 602 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java create mode 100644 service/src/main/proto/NoiseTunnel.proto diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java index d529754a1..fc016104a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java @@ -10,8 +10,12 @@ import org.whispersystems.textsecuregcm.storage.Device; public class DeviceIdUtil { + public static boolean isValid(int deviceId) { + return deviceId >= Device.PRIMARY_ID && deviceId <= Byte.MAX_VALUE; + } + static byte validate(int deviceId) { - if (deviceId < Device.PRIMARY_ID || deviceId > Byte.MAX_VALUE) { + if (!isValid(deviceId)) { throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java deleted file mode 100644 index beeb1469a..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java +++ /dev/null @@ -1,9 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import org.whispersystems.textsecuregcm.util.NoStackTraceException; - -/** - * Indicates that an attempt to authenticate a remote client failed for some reason. - */ -public class ClientAuthenticationException extends NoStackTraceException { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java index a62fe0684..a65c3f5aa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java @@ -1,10 +1,9 @@ package org.whispersystems.textsecuregcm.grpc.net; -import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import javax.crypto.BadPaddingException; import io.netty.channel.ChannelInboundHandlerAdapter; +import javax.crypto.BadPaddingException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.ExceptionUtils; @@ -16,9 +15,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; public class ErrorHandler extends ChannelInboundHandlerAdapter { private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class); - private static OutboundCloseErrorMessage UNAUTHENTICATED_CLOSE = new OutboundCloseErrorMessage( - OutboundCloseErrorMessage.Code.AUTHENTICATION_ERROR, - "Not authenticated"); private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage( OutboundCloseErrorMessage.Code.NOISE_ERROR, "Noise encryption error"); @@ -29,7 +25,6 @@ public class ErrorHandler extends ChannelInboundHandlerAdapter { case NoiseHandshakeException e -> new OutboundCloseErrorMessage( OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR, e.getMessage()); - case ClientAuthenticationException ignored -> UNAUTHENTICATED_CLOSE; case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; default -> { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java index 59f691dbe..e554afa5d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -8,6 +8,7 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.util.ReferenceCountUtil; +import java.net.InetAddress; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -48,7 +49,9 @@ public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAd @Override public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { - if (event instanceof NoiseIdentityDeterminedEvent(final Optional authenticatedDevice)) { + if (event instanceof NoiseIdentityDeterminedEvent( + final Optional authenticatedDevice, + InetAddress remoteAddress, String userAgent, String acceptLanguage)) { // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the @@ -57,6 +60,9 @@ public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAd ? authenticatedGrpcServerAddress : anonymousGrpcServerAddress; + GrpcClientConnectionManager.handleHandshakeInitiated( + remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage); + new Bootstrap() .remoteAddress(grpcServerAddress) .channel(LocalChannel.class) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java deleted file mode 100644 index 411f5094e..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandler.java +++ /dev/null @@ -1,34 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import org.signal.libsignal.protocol.ecc.ECKeyPair; - -/** - * A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake - * and noise encryption/decryption. - *

- * A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an - * initial request in their payload. If provided, this allows the server to begin processing the request without an - * initial message delay (fast open). - *

- * Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent} - * indicating that initiator connected anonymously. - */ -public class NoiseAnonymousHandler extends NoiseHandler { - - public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) { - super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair)); - } - - @Override - CompletableFuture handleHandshakePayload(final ChannelHandlerContext context, - final Optional initiatorPublicKey, final ByteBuf handshakePayload) { - return CompletableFuture.completedFuture(new HandshakeResult( - handshakePayload, - Optional.empty() - )); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java deleted file mode 100644 index d28910eee..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandler.java +++ /dev/null @@ -1,96 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.util.ReferenceCountUtil; -import java.security.MessageDigest; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; - -/** - * A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake - * and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public - * key is authenticated by the responder. - *

- * The authenticated handshake requires the initiator to provide a payload with their first handshake message that - * includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an - * initial request in their payload. If provided, this allows the server to begin processing the request without an - * initial message delay (fast open). - *

- * +-----------------+----------------+------------------------+
- * |    UUID (16)    |  deviceId (1)  |     request bytes (N)  |
- * +-----------------+----------------+------------------------+
- * 
- *

- * For a successful handshake, the static key provided in the handshake message must match the server's stored public - * key for the device identified by the provided ACI and deviceId. - *

- * As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}. - */ -public class NoiseAuthenticatedHandler extends NoiseHandler { - - private final ClientPublicKeysManager clientPublicKeysManager; - - public NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair) { - super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair)); - this.clientPublicKeysManager = clientPublicKeysManager; - } - - @Override - CompletableFuture handleHandshakePayload( - final ChannelHandlerContext context, - final Optional initiatorPublicKey, - final ByteBuf handshakePayload) throws NoiseHandshakeException { - if (handshakePayload.readableBytes() < 17) { - throw new NoiseHandshakeException("Invalid handshake payload"); - } - - final byte[] publicKeyFromClient = initiatorPublicKey - .orElseThrow(() -> new IllegalStateException("No remote public key")); - - // Advances the read index by 16 bytes - final UUID accountIdentifier = parseUUID(handshakePayload); - - // Advances the read index by 1 byte - final byte deviceId = handshakePayload.readByte(); - - final ByteBuf fastOpenRequest = handshakePayload.slice(); - return clientPublicKeysManager - .findPublicKey(accountIdentifier, deviceId) - .handleAsync((storedPublicKey, throwable) -> { - if (throwable != null) { - ReferenceCountUtil.release(fastOpenRequest); - throw ExceptionUtils.wrap(throwable); - } - final boolean valid = storedPublicKey - .map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes())) - .orElse(false); - if (!valid) { - throw ExceptionUtils.wrap(new ClientAuthenticationException()); - } - return new HandshakeResult( - fastOpenRequest, - Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))); - }, context.executor()); - } - - /** - * Parse a {@link UUID} out of bytes, advancing the readerIdx by 16 - * - * @param bytes The {@link ByteBuf} to read from - * @return The parsed UUID - * @throws NoiseHandshakeException If a UUID could not be parsed from bytes - */ - private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException { - if (bytes.readableBytes() < 16) { - throw new NoiseHandshakeException("Could not parse account identifier"); - } - return new UUID(bytes.readLong(), bytes.readLong()); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java index 7390b099c..42e871e35 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java @@ -11,159 +11,59 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.PromiseCombiner; -import io.netty.util.internal.EmptyArrays; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; /** - * A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts - * inbound messages, and encrypts outbound messages + * A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound + * messages */ -public abstract class NoiseHandler extends ChannelDuplexHandler { +public class NoiseHandler extends ChannelDuplexHandler { private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class); + private final CipherStatePair cipherStatePair; - private enum State { - // Waiting for handshake to complete - HANDSHAKE, - // Can freely exchange encrypted noise messages on an established session - TRANSPORT, - // Finished with error - ERROR + NoiseHandler(CipherStatePair cipherStatePair) { + this.cipherStatePair = cipherStatePair; } - private final NoiseHandshakeHelper handshakeHelper; - - private State state = State.HANDSHAKE; - private CipherStatePair cipherStatePair; - - NoiseHandler(NoiseHandshakeHelper handshakeHelper) { - this.handshakeHelper = handshakeHelper; - } - - /** - * The result of processing an initiator handshake payload - * - * @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an - * empty ByteBuf - * @param authenticatedDevice If present, the successfully authenticated initiator identity - */ - record HandshakeResult(ByteBuf fastOpenRequest, Optional authenticatedDevice) {} - - /** - * Parse and potentially authenticate the initiator handshake message - * - * @param context A {@link ChannelHandlerContext} - * @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used - * @param handshakePayload The handshake payload provided in the initiator message - * @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was - * present in the handshake payload. - * @throws NoiseHandshakeException If the handshake payload was invalid - * @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated - */ - abstract CompletableFuture handleHandshakePayload( - final ChannelHandlerContext context, - final Optional initiatorPublicKey, - final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException; - @Override public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { try { if (message instanceof ByteBuf frame) { if (frame.readableBytes() > Noise.MAX_PACKET_LEN) { - final String error = "Invalid noise message length " + frame.readableBytes(); - throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error); + throw new NoiseException("Invalid noise message length " + frame.readableBytes()); } // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. // We'll need to copy it to a heap buffer. - handleInboundMessage(context, ByteBufUtil.getBytes(frame)); + handleInboundDataMessage(context, ByteBufUtil.getBytes(frame)); } else { // Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error throw new IllegalArgumentException("Unexpected message in pipeline: " + message); } - } catch (Exception e) { - fail(context, e); } finally { ReferenceCountUtil.release(message); } } - private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes) - throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException { - switch (state) { - // Got an initiator handshake message - case HANDSHAKE -> { - final ByteBuf payload = handshakeHelper.read(frameBytes); - handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync( - (result, throwable) -> { - if (state == State.ERROR) { - return; - } - if (throwable != null) { - fail(context, ExceptionUtils.unwrap(throwable)); - return; - } - context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice())); + private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes) + throws ShortBufferException, BadPaddingException { + final CipherState cipherState = cipherStatePair.getReceiver(); + // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer + final int plaintextLength = cipherState.decryptWithAd(null, + frameBytes, 0, + frameBytes, 0, + frameBytes.length); - // Now that we've authenticated, write the handshake response - byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES); - context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - - // The handshake is complete. We can start intercepting read/write for noise encryption/decryption - this.state = State.TRANSPORT; - this.cipherStatePair = handshakeHelper.getHandshakeState().split(); - if (result.fastOpenRequest().isReadable()) { - // The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll - // encrypt the response when the server writes back through us - context.fireChannelRead(result.fastOpenRequest()); - } else { - ReferenceCountUtil.release(result.fastOpenRequest()); - } - }, context.executor()); - } - - // Got a client message that should be decrypted and forwarded - case TRANSPORT -> { - final CipherState cipherState = cipherStatePair.getReceiver(); - // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer - final int plaintextLength = cipherState.decryptWithAd(null, - frameBytes, 0, - frameBytes, 0, - frameBytes.length); - - // Forward the decrypted plaintext along - context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength)); - } - - // The session is already in an error state, drop the message - case ERROR -> { - } - } - } - - /** - * Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the - * context - */ - private void fail(final ChannelHandlerContext context, final Throwable cause) { - this.state = State.ERROR; - context.fireExceptionCaught(cause); + // Forward the decrypted plaintext along + context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength)); } @Override @@ -208,4 +108,12 @@ public abstract class NoiseHandler extends ChannelDuplexHandler { context.write(message, promise); } } + + @Override + public void handlerRemoved(ChannelHandlerContext var1) { + if (cipherStatePair != null) { + cipherStatePair.destroy(); + } + } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java new file mode 100644 index 000000000..80d7f6105 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java @@ -0,0 +1,197 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.Noise; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; +import java.io.IOException; +import java.net.InetAddress; +import java.security.MessageDigest; +import java.util.Optional; +import java.util.UUID; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +/** + * Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will + * encrypt/decrypt subsequent data frames + *

+ * The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator + * handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler + * currently supports two types of handshakes. + *

+ * The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator + * handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in + * the handshake message must match the server's stored public key for the device identified by the provided ACI and + * deviceId. + *

+ * The second are NK handshakes which are anonymous. + *

+ * Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server + * to begin processing the request without an initial message delay (fast open). + *

+ * Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake, + * this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This + * handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data + * frames. + */ +public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter { + + private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder() + .setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY) + .build().toByteArray(); + private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder() + .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK) + .build().toByteArray(); + + // We might get additional messages while we're waiting to process a handshake, so keep track of where we are + private boolean receivedHandshakeInit = false; + + private final ClientPublicKeysManager clientPublicKeysManager; + private final ECKeyPair ecKeyPair; + + public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) { + this.clientPublicKeysManager = clientPublicKeysManager; + this.ecKeyPair = ecKeyPair; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + try { + if (!(message instanceof NoiseHandshakeInit handshakeInit)) { + // Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + if (receivedHandshakeInit) { + throw new NoiseHandshakeException("Should not receive messages until handshake complete"); + } + receivedHandshakeInit = true; + + if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) { + throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes()); + } + + // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. + // We'll need to copy it to a heap buffer + handleInboundHandshake(context, + handshakeInit.getRemoteAddress(), + handshakeInit.getHandshakePattern(), + ByteBufUtil.getBytes(handshakeInit.content())); + } finally { + ReferenceCountUtil.release(message); + } + } + + private void handleInboundHandshake( + final ChannelHandlerContext context, + final InetAddress remoteAddress, + final HandshakePattern handshakePattern, + final byte[] frameBytes) throws NoiseHandshakeException { + final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair); + final ByteBuf payload = handshakeHelper.read(frameBytes); + + // Parse the handshake message + final NoiseTunnelProtos.HandshakeInit handshakeInit; + try { + handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload)); + } catch (IOException e) { + throw new NoiseHandshakeException("Failed to parse handshake message"); + } + + switch (handshakePattern) { + case NK -> { + if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) { + throw new NoiseHandshakeException("Anonymous handshake should not include identifiers"); + } + handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty()); + } + case IK -> { + final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey() + .orElseThrow(() -> new IllegalStateException("No remote public key")); + final UUID accountIdentifier = aci(handshakeInit); + final byte deviceId = deviceId(handshakeInit); + clientPublicKeysManager + .findPublicKey(accountIdentifier, deviceId) + .whenCompleteAsync((storedPublicKey, throwable) -> { + if (throwable != null) { + context.fireExceptionCaught(ExceptionUtils.unwrap(throwable)); + return; + } + final boolean valid = storedPublicKey + .map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes())) + .orElse(false); + if (!valid) { + // Write a handshake response indicating that the client used the wrong public key + final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK); + context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + + context.fireExceptionCaught(new NoiseHandshakeException("Bad public key")); + return; + } + handleAuthenticated(context, + handshakeHelper, remoteAddress, handshakeInit, + Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))); + }, context.executor()); + } + }; + } + + private void handleAuthenticated(final ChannelHandlerContext context, + final NoiseHandshakeHelper handshakeHelper, + final InetAddress remoteAddress, + final NoiseTunnelProtos.HandshakeInit handshakeInit, + final Optional maybeAuthenticatedDevice) { + context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent( + maybeAuthenticatedDevice, + remoteAddress, + handshakeInit.getUserAgent(), + handshakeInit.getAcceptLanguage())); + + // Now that we've authenticated, write the handshake response + final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK); + context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + + // The handshake is complete. We can start intercepting read/write for noise encryption/decryption + // Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open + // request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already + // decrypted the fastOpen request + context.pipeline() + .addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split())); + context.pipeline().remove(NoiseHandshakeHandler.class); + if (!handshakeInit.getFastOpenRequest().isEmpty()) { + // The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll + // encrypt the response when the server writes back through us + context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer())); + } + } + + private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException { + try { + return UUIDUtil.fromByteString(handshakePayload.getAci()); + } catch (IllegalArgumentException e) { + throw new NoiseHandshakeException("Could not parse aci"); + } + } + + private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException { + if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) { + throw new NoiseHandshakeException("Invalid deviceId"); + } + return (byte) handshakePayload.getDeviceId(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java new file mode 100644 index 000000000..d5368f366 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java @@ -0,0 +1,33 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.DefaultByteBufHolder; +import java.net.InetAddress; + +/** + * A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata + * and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic. + */ +public class NoiseHandshakeInit extends DefaultByteBufHolder { + + private final InetAddress remoteAddress; + private final HandshakePattern handshakePattern; + + public NoiseHandshakeInit( + final InetAddress remoteAddress, + final HandshakePattern handshakePattern, + final ByteBuf initiatorHandshakeMessage) { + super(initiatorHandshakeMessage); + this.remoteAddress = remoteAddress; + this.handshakePattern = handshakePattern; + } + + public InetAddress getRemoteAddress() { + return remoteAddress; + } + + public HandshakePattern getHandshakePattern() { + return handshakePattern; + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java index 130bda9d8..a74c3486f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java @@ -1,5 +1,6 @@ package org.whispersystems.textsecuregcm.grpc.net; +import java.net.InetAddress; import java.util.Optional; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; @@ -9,5 +10,12 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; * * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a * type that performs authentication + * @param remoteAddress the remote address of the connecting client + * @param userAgent the client supplied userAgent + * @param acceptLanguage the client supplied acceptLanguage */ -public record NoiseIdentityDeterminedEvent(Optional authenticatedDevice) {} +public record NoiseIdentityDeterminedEvent( + Optional authenticatedDevice, + InetAddress remoteAddress, + String userAgent, + String acceptLanguage) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java index ca2d335bc..3ef4f84f1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java @@ -9,6 +9,7 @@ package org.whispersystems.textsecuregcm.grpc.net; */ public record OutboundCloseErrorMessage(Code code, String message) { public enum Code { + /** * The server decided to close the connection. This could be because the server is going away, or it could be * because the credentials for the connected client have been updated. @@ -25,11 +26,6 @@ public record OutboundCloseErrorMessage(Code code, String message) { */ NOISE_HANDSHAKE_ERROR, - /** - * The provided credentials were not valid - */ - AUTHENTICATION_ERROR, - INTERNAL_SERVER_ERROR } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java index 96db6280d..3545d5797 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java @@ -30,12 +30,12 @@ public class NoiseDirectFrame extends DefaultByteBufHolder { public enum FrameType { /** - * The payload is the initiator message or the responder message for a Noise NK handshake. If established, the + * The payload is the initiator message for a Noise NK handshake. If established, the * session will be unauthenticated. */ NK_HANDSHAKE((byte) 1), /** - * The payload is the initiator message or the responder message for a Noise IK handshake. If established, the + * The payload is the initiator message for a Noise IK handshake. If established, the * session will be authenticated. */ IK_HANDSHAKE((byte) 2), @@ -44,9 +44,10 @@ public class NoiseDirectFrame extends DefaultByteBufHolder { */ DATA((byte) 3), /** - * A framing layer error occurred. The payload carries error details. + * A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being + * closed. */ - ERROR((byte) 4); + CLOSE((byte) 4); private final byte frameType; @@ -64,7 +65,7 @@ public class NoiseDirectFrame extends DefaultByteBufHolder { public boolean isHandshake() { return switch (this) { case IK_HANDSHAKE, NK_HANDSHAKE -> true; - case DATA, ERROR -> false; + case DATA, CLOSE -> false; }; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java index 184a58323..4d94d0400 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java @@ -76,7 +76,7 @@ public class NoiseDirectFrameCodec extends ChannelDuplexHandler { case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE; case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE; case 3 -> NoiseDirectFrame.FrameType.DATA; - case 4 -> NoiseDirectFrame.FrameType.ERROR; + case 4 -> NoiseDirectFrame.FrameType.CLOSE; default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits); }; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java index a2204b660..bf72bf927 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java @@ -4,66 +4,44 @@ */ package org.whispersystems.textsecuregcm.grpc.net.noisedirect; -import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.util.ReferenceCountUtil; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - import java.io.IOException; import java.net.InetSocketAddress; +import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; /** - * Waits for a Handshake {@link NoiseDirectFrame} and then installs a {@link NoiseDirectDataFrameCodec} and - * {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} and removes itself + * Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and + * forwards the handshake frame along as a {@link NoiseHandshakeInit} message */ public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter { - private final ClientPublicKeysManager clientPublicKeysManager; - private final ECKeyPair ecKeyPair; - - public NoiseDirectHandshakeSelector(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) { - this.clientPublicKeysManager = clientPublicKeysManager; - this.ecKeyPair = ecKeyPair; - } - - @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { if (msg instanceof NoiseDirectFrame frame) { try { - // We've received an inbound handshake frame so we know what kind of NoiseHandler we need (authenticated or - // anonymous). We construct it here, and then remember the handshake type so we can annotate our handshake - // response with the correct frame type whenever we receive it. - final ChannelDuplexHandler noiseHandler = switch (frame.frameType()) { - case DATA, ERROR -> - throw new NoiseHandshakeException("Invalid frame type for first message " + frame.frameType()); - case IK_HANDSHAKE -> new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair); - case NK_HANDSHAKE -> new NoiseAnonymousHandler(ecKeyPair); - }; - if (ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) { - // TODO: Provide connection metadata / headers in handshake payload - GrpcClientConnectionManager.handleHandshakeInitiated(ctx.channel(), - inetSocketAddress.getAddress(), - "NoiseDirect", - ""); - - } else { + if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) { throw new IOException("Could not determine remote address"); } + // We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler + // needs into a NoiseHandshakeInit message and forward that along + final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(), + switch (frame.frameType()) { + case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type"); + case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector"); + case IK_HANDSHAKE -> HandshakePattern.IK; + case NK_HANDSHAKE -> HandshakePattern.NK; + }, frame.content()); // Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames // should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded // for network serialization. Note that we need to install the Data frame handler before firing the read, // because we may receive an outbound message from the noiseHandler - ctx.pipeline().addAfter(ctx.name(), null, noiseHandler); ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec()); - ctx.fireChannelRead(frame.content()); + ctx.fireChannelRead(handshakeMessage); } catch (Exception e) { ReferenceCountUtil.release(msg); throw e; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java new file mode 100644 index 000000000..81dbc7567 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc.net.noisedirect; + +import io.micrometer.core.instrument.Metrics; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + + +/** + * Watches for inbound close frames and closes the connection in response + */ +public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter { + private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(ChannelInboundHandlerAdapter.class, "clientClose"); + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { + try { + final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason + .parseFrom(ByteBufUtil.getBytes(ndf.content())); + + Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment(); + } finally { + ReferenceCountUtil.release(msg); + ctx.close(); + } + } else { + ctx.fireChannelRead(msg); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java index 500400cd9..caa289609 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java @@ -17,20 +17,19 @@ class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (msg instanceof OutboundCloseErrorMessage err) { - final NoiseDirectProtos.Error.Type type = switch (err.code()) { - case SERVER_CLOSED -> NoiseDirectProtos.Error.Type.UNAVAILABLE; - case NOISE_ERROR -> NoiseDirectProtos.Error.Type.ENCRYPTION_ERROR; - case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.Error.Type.HANDSHAKE_ERROR; - case AUTHENTICATION_ERROR -> NoiseDirectProtos.Error.Type.AUTHENTICATION_ERROR; - case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.Error.Type.INTERNAL_ERROR; + final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) { + case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE; + case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR; + case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR; + case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR; }; - final NoiseDirectProtos.Error proto = NoiseDirectProtos.Error.newBuilder() - .setType(type) + final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder() + .setCode(code) .setMessage(err.message()) .build(); final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize()); proto.writeTo(new ByteBufOutputStream(byteBuf)); - ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.ERROR, byteBuf)) + ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf)) .addListener(ChannelFutureListener.CLOSE); } else { ctx.write(msg, promise); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java index 3bca2e69e..89720ee65 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java @@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler; import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler; import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; @@ -50,18 +51,20 @@ public class NoiseDirectTunnelServer implements Managed { protected void initChannel(SocketChannel socketChannel) { socketChannel.pipeline() .addLast(new ProxyProtocolDetectionHandler()) - .addLast(new HAProxyMessageHandler()); - - socketChannel.pipeline() + .addLast(new HAProxyMessageHandler()) // frame byte followed by a 2-byte length field .addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2)) // Parses NoiseDirectFrames from wire bytes and vice versa .addLast(new NoiseDirectFrameCodec()) + // Terminate the connection if the client sends us a close frame + .addLast(new NoiseDirectInboundCloseHandler()) // Turn generic OutboundCloseErrorMessages into noise direct error frames .addLast(new NoiseDirectOutboundErrorHandler()) - // Waits for the handshake to finish and then replaces itself with a NoiseDirectFrameCodec and a - // NoiseHandler to handle noise encryption/decryption - .addLast(new NoiseDirectHandshakeSelector(clientPublicKeysManager, ecKeyPair)) + // Forwards the first payload supplemented with handshake metadata, and then replaces itself with a + // NoiseDirectDataFrameCodec to handle subsequent data frames + .addLast(new NoiseDirectHandshakeSelector()) + // Performs the noise handshake and then replace itself with a NoiseHandler + .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair)) // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler // once the Noise handshake has completed .addLast(new EstablishLocalGrpcConnectionHandler( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java index c4dc8528a..67a24b99f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java @@ -4,8 +4,7 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; enum ApplicationWebSocketCloseReason { NOISE_HANDSHAKE_ERROR(4001), - CLIENT_AUTHENTICATION_ERROR(4002), - NOISE_ENCRYPTION_ERROR(4003); + NOISE_ENCRYPTION_ERROR(4002); private final int statusCode; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java index cd37ac603..ba4835e07 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java @@ -108,9 +108,12 @@ public class NoiseWebSocketTunnelServer implements Managed { .addLast(new WebSocketOutboundErrorHandler()) .addLast(new RejectUnsupportedMessagesHandler()) .addLast(new WebsocketPayloadCodec()) - // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once - // a WebSocket handshake has been completed - .addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret)) + // The WebSocket handshake complete listener will forward the first payload supplemented with + // data from the websocket handshake completion event, and then remove itself from the pipeline + .addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret)) + // The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a + // NoiseHandler + .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair)) // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler // once the Noise handshake has completed .addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java index 4dfb54be4..bf63b206e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java @@ -7,14 +7,9 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import javax.crypto.BadPaddingException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.ClientAuthenticationException; -import org.whispersystems.textsecuregcm.grpc.net.NoiseException; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; /** * Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames @@ -46,7 +41,6 @@ class WebSocketOutboundErrorHandler extends ChannelDuplexHandler { case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code(); case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode(); case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode(); - case AUTHENTICATION_ERROR -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode(); case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(); }; ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java index b41e4328d..7093934bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java @@ -2,14 +2,14 @@ package org.whispersystems.textsecuregcm.grpc.net.websocket; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.InetAddresses; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import io.netty.util.ReferenceCountUtil; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; @@ -17,13 +17,10 @@ import java.security.MessageDigest; import java.util.Optional; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; -import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; /** * A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate @@ -31,10 +28,6 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; */ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { - private final ClientPublicKeysManager clientPublicKeysManager; - - private final ECKeyPair ecKeyPair; - private final byte[] recognizedProxySecret; private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class); @@ -45,12 +38,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { @VisibleForTesting static final String FORWARDED_FOR_HEADER = "X-Forwarded-For"; - WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair, - final String recognizedProxySecret) { + private InetAddress remoteAddress = null; + private HandshakePattern handshakePattern = null; - this.clientPublicKeysManager = clientPublicKeysManager; - this.ecKeyPair = ecKeyPair; + WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) { // The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex- // encoded value). We convert it into a byte array here for easier constant-time comparisons via @@ -61,53 +52,58 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { @Override public void userEventTriggered(final ChannelHandlerContext context, final Object event) { if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { - final InetAddress preferredRemoteAddress; - { - final Optional maybePreferredRemoteAddress = - getPreferredRemoteAddress(context, handshakeCompleteEvent); + final Optional maybePreferredRemoteAddress = + getPreferredRemoteAddress(context, handshakeCompleteEvent); - if (maybePreferredRemoteAddress.isEmpty()) { - context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, - "Could not determine remote address")) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + if (maybePreferredRemoteAddress.isEmpty()) { + context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, + "Could not determine remote address")) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - return; - } - - preferredRemoteAddress = maybePreferredRemoteAddress.get(); + return; } - GrpcClientConnectionManager.handleHandshakeInitiated(context.channel(), - preferredRemoteAddress, - handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT), - handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE)); - - final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) { - case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> - new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair); - - case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> - new NoiseAnonymousHandler(ecKeyPair); - - default -> { - // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an - // internal error if something slipped through. - throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); - } + remoteAddress = maybePreferredRemoteAddress.get(); + handshakePattern = switch (handshakeCompleteEvent.requestUri()) { + case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK; + case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK; + // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an + // internal error if something slipped through. + default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); }; - - context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler); } context.fireUserEventTriggered(event); } + @Override + public void channelRead(final ChannelHandlerContext context, final Object msg) { + try { + if (!(msg instanceof ByteBuf frame)) { + throw new IllegalStateException("Unexpected msg type: " + msg.getClass()); + } + + if (handshakePattern == null || remoteAddress == null) { + throw new IllegalStateException("Received payload before websocket handshake complete"); + } + + final NoiseHandshakeInit handshakeMessage = + new NoiseHandshakeInit(remoteAddress, handshakePattern, frame); + + context.pipeline().remove(WebsocketHandshakeCompleteHandler.class); + context.fireChannelRead(handshakeMessage); + } catch (Exception e) { + ReferenceCountUtil.release(msg); + throw e; + } + } + private Optional getPreferredRemoteAddress(final ChannelHandlerContext context, final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { final byte[] recognizedProxySecretFromHeader = handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "") - .getBytes(StandardCharsets.UTF_8); + .getBytes(StandardCharsets.UTF_8); final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader); diff --git a/service/src/main/proto/NoiseDirect.proto b/service/src/main/proto/NoiseDirect.proto index 10a7801df..a5c1ecf8c 100644 --- a/service/src/main/proto/NoiseDirect.proto +++ b/service/src/main/proto/NoiseDirect.proto @@ -8,15 +8,46 @@ syntax = "proto3"; option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect"; option java_outer_classname = "NoiseDirectProtos"; -message Error { - enum Type { +message CloseReason { + enum Code { UNSPECIFIED = 0; - HANDSHAKE_ERROR = 1; - ENCRYPTION_ERROR = 2; - UNAVAILABLE = 3; - INTERNAL_ERROR = 4; - AUTHENTICATION_ERROR = 5; + // Indicates non-error termination + // Examples: + // - The client is finished with the connection + OK = 1; + + // There was an issue with the handshake. If sent after a handshake response, + // the response includes more information about the nature of the error + // Examples: + // - The client did not provide a handshake message + // - The client had incorrect authentication credentials. The handshake + // payload includes additional details + HANDSHAKE_ERROR = 2; + + // There was an encryption/decryption issue after the handshake + // Examples: + // - The client incorrectly encrypted a noise message and it had a bad + // AEAD tag + ENCRYPTION_ERROR = 3; + + // The server is temporarily unavailable, going away, or requires a + // connection reset + // Examples: + // - The server is shutting down + // - The client’s authentication credentials have been rotated + UNAVAILABLE = 4; + + // There was an an internal error + // Examples: + // - The server experienced a temporary database outage that prevented it + // from checking the client's credentials + INTERNAL_ERROR = 5; } - Type type = 1; + + Code code = 1; + + // If present, includes details about the error. Implementations should never + // parse or otherwise implement conditional logic based on the contents of the + // error message string, it is for logging and debugging purposes only. string message = 2; } diff --git a/service/src/main/proto/NoiseTunnel.proto b/service/src/main/proto/NoiseTunnel.proto new file mode 100644 index 000000000..45695e20a --- /dev/null +++ b/service/src/main/proto/NoiseTunnel.proto @@ -0,0 +1,56 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +syntax = "proto3"; + +option java_package = "org.whispersystems.textsecuregcm.grpc.net"; +option java_outer_classname = "NoiseTunnelProtos"; + +message HandshakeInit { + string user_agent = 1; + + // An Accept-Language as described in + // https://httpwg.org/specs/rfc9110.html#field.accept-language + string accept_language = 2; + + // A UUID serialized as 16 bytes (big end first). Must be unset (empty) for an + // unauthenticated handshake + bytes aci = 3; + + // The deviceId, 0 < deviceId < 128. Must be unset for an unauthenticated + // handshake + uint32 device_id = 4; + + // The first bytes of the application request byte stream, may contain less + // than a full request + bytes fast_open_request = 5; +} + +message HandshakeResponse { + enum Code { + UNSPECIFIED = 0; + + // The noise session may be used to send application layer requests + OK = 1; + + // The provided client static key did not match the registered public key + // for the provided aci/deviceId. + WRONG_PUBLIC_KEY = 2; + + // The client version is to old, it should be upgraded before retrying + DEPRECATED = 3; + } + + // The handshake outcome + Code code = 1; + + // Additional information about an error status, for debugging only + string error_details = 2; + + // An optional response to a fast_open_request provided in the HandshakeInit. + // Note that a response may not be present even if a fast_open_request was + // present. If so, the response will be returned in a later message. + bytes fast_open_response = 3; +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java index 0b699b543..ffbd5082d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.Noise; @@ -16,12 +17,13 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.util.ReferenceCountUtil; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; @@ -36,16 +38,29 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.util.TestRandomUtil; abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { protected ECKeyPair serverKeyPair; + protected ClientPublicKeysManager clientPublicKeysManager; private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler; private EmbeddedChannel embeddedChannel; + static final String USER_AGENT = "Test/User-Agent"; + static final String ACCEPT_LANGUAGE = "test-lang"; + static final InetAddress REMOTE_ADDRESS; + static { + try { + REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3}); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + } + private static class PongHandler extends ChannelInboundHandlerAdapter { @Override @@ -93,7 +108,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { void setUp() { serverKeyPair = Curve.generateKeyPair(); noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler(); - embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler); + clientPublicKeysManager = mock(ClientPublicKeysManager.class); + embeddedChannel = new EmbeddedChannel( + new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair), + noiseHandshakeCompleteHandler); } @AfterEach @@ -110,8 +128,6 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent(); } - protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair); - protected abstract CipherStatePair doHandshake() throws Throwable; /** @@ -140,7 +156,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await(); + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await(); assertFalse(writeFuture.isSuccess()); assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); @@ -292,4 +308,19 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big)); assertThrows(NoiseException.class, embeddedChannel::checkException); } + + @Test + public void channelAttributes() throws Throwable { + doHandshake(); + final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent(); + assertEquals(REMOTE_ADDRESS, event.remoteAddress()); + assertEquals(USER_AGENT, event.userAgent()); + assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage()); + } + + protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() { + return NoiseTunnelProtos.HandshakeInit.newBuilder() + .setUserAgent(USER_AGENT) + .setAcceptLanguage(ACCEPT_LANGUAGE); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java index f2d75e196..8bccb5971 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java @@ -248,7 +248,10 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL } finally { channel.shutdown(); } - assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR); + assertEquals( + NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY, + client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode()); + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); } } @@ -269,12 +272,35 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL } finally { channel.shutdown(); } - - assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR); + assertEquals( + NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY, + client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode()); + assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); } } + @Test + void clientNormalClosure() throws InterruptedException { + final NoiseTunnelClient client = anonymous().build(); + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertTrue(response.getAccountIdentifier().isEmpty()); + assertEquals(0, response.getDeviceId()); + client.close(); + + // When we gracefully close the tunnel client, we should send an OK close frame + final CloseFrameEvent closeFrame = client.closeFrameFuture().join(); + assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator()); + assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason()); + } finally { + channel.shutdown(); + } + } + @Test void connectAnonymous() throws InterruptedException { try (final NoiseTunnelClient client = anonymous().build()) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java index 987e7f166..0a775ccbf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java @@ -8,28 +8,23 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.google.protobuf.ByteString; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.HandshakeState; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import java.util.Optional; import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; import org.junit.jupiter.api.Test; -import org.signal.libsignal.protocol.ecc.ECKeyPair; class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { - @Override - protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) { - - return new NoiseAnonymousHandler(serverKeyPair); - } - @Override protected CipherStatePair doHandshake() throws Exception { - return doHandshake(new byte[0]); + return doHandshake(baseHandshakeInit().build().toByteArray()); } private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception { @@ -49,26 +44,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { assertEquals( initiateHandshakeMessageLength, clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length)); - final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage); - assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess()); - assertEquals(0, initiateHandshakeMessageBuf.refCnt()); + final NoiseHandshakeInit message = new NoiseHandshakeInit( + REMOTE_ADDRESS, + HandshakePattern.NK, + Unpooled.wrappedBuffer(initiateHandshakeMessage)); + assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess()); + assertEquals(0, message.refCnt()); embeddedChannel.runPendingTasks(); // Read responder handshake message assertFalse(embeddedChannel.outboundMessages().isEmpty()); final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - @SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes = - new byte[responderHandshakeFrame.readableBytes()]; - responderHandshakeFrame.readBytes(responderHandshakeBytes); + assertNotNull(responderHandshakeFrame); + final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame); - // ephemeral key, empty encrypted payload AEAD tag - final byte[] handshakeResponsePayload = new byte[32 + 16]; + final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder() + .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK) + .build(); - assertEquals(0, + // ephemeral key, payload, AEAD tag + assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length); + + final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()]; + assertEquals(expectedHandshakeResponse.getSerializedSize(), clientHandshakeState.readMessage( responderHandshakeBytes, 0, responderHandshakeBytes.length, - handshakeResponsePayload, 0)); + handshakeResponsePlaintext, 0)); + + assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext)); final byte[] serverPublicKey = new byte[32]; clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); @@ -78,27 +82,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { } @Test - void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException { + void handleCompleteHandshakeWithRequest() throws Exception { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes())); + final byte[] handshakePlaintext = baseHandshakeInit() + .setFastOpenRequest(ByteString.copyFromUtf8("ping")).build() + .toByteArray(); + + final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext); final byte[] response = readNextPlaintext(cipherStatePair); assertArrayEquals(response, "pong".getBytes()); - assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); + assertEquals( + new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), + getNoiseHandshakeCompleteEvent()); } @Test void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0])); + final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake()); assertNull(readNextPlaintext(cipherStatePair)); - assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); + assertEquals( + new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), + getNoiseHandshakeCompleteEvent()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java index b498ef954..cfb85a526 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java @@ -8,18 +8,20 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import com.google.protobuf.ByteString; import com.southernstorm.noise.protocol.CipherStatePair; import com.southernstorm.noise.protocol.HandshakeState; import com.southernstorm.noise.protocol.Noise; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.internal.EmptyArrays; +import java.io.IOException; import java.nio.ByteBuffer; import java.security.NoSuchAlgorithmException; import java.util.Optional; @@ -28,40 +30,24 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil; class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { - private ClientPublicKeysManager clientPublicKeysManager; private final ECKeyPair clientKeyPair = Curve.generateKeyPair(); - @Override - @BeforeEach - void setUp() { - clientPublicKeysManager = mock(ClientPublicKeysManager.class); - - super.setUp(); - } - - @Override - protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) { - return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair); - } - @Override protected CipherStatePair doHandshake() throws Throwable { final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1); when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); return doHandshake(identityPayload(accountIdentifier, deviceId)); @@ -71,7 +57,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { void handleCompleteHandshakeNoInitialRequest() throws Throwable { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); @@ -81,7 +67,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId)))); - assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + assertEquals( + new NoiseIdentityDeterminedEvent( + Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)), + REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), getNoiseHandshakeCompleteEvent()); } @@ -89,7 +78,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { void handleCompleteHandshakeWithInitialRequest() throws Throwable { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); @@ -97,15 +86,19 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - final ByteBuffer bb = ByteBuffer.allocate(17 + 4); - bb.put(identityPayload(accountIdentifier, deviceId)); - bb.put("ping".getBytes()); + final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId) + .setFastOpenRequest(ByteString.copyFromUtf8("ping")) + .build() + .toByteArray(); - final byte[] response = readNextPlaintext(doHandshake(bb.array())); - assertEquals(response.length, 4); - assertEquals(new String(response), "pong"); + final byte[] response = readNextPlaintext(doHandshake(handshakeInit)); + assertEquals(4, response.length); + assertEquals("pong", new String(response)); - assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + assertEquals( + new NoiseIdentityDeterminedEvent( + Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)), + REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), getNoiseHandshakeCompleteEvent()); } @@ -113,7 +106,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { void handleCompleteHandshakeMissingIdentityInformation() { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES)); @@ -121,7 +114,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { assertNull(getNoiseHandshakeCompleteEvent()); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), "Handshake handler should not remove self from pipeline after failed handshake"); assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), @@ -132,7 +125,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { void handleCompleteHandshakeMalformedIdentityInformation() { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); // no deviceId byte byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID()); @@ -142,7 +135,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { assertNull(getNoiseHandshakeCompleteEvent()); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), "Handshake handler should not remove self from pipeline after failed handshake"); assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), @@ -150,10 +143,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { } @Test - void handleCompleteHandshakeUnrecognizedDevice() { + void handleCompleteHandshakeUnrecognizedDevice() throws Throwable { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); @@ -161,11 +154,13 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); + doHandshake( + identityPayload(accountIdentifier, deviceId), + NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY); assertNull(getNoiseHandshakeCompleteEvent()); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), "Handshake handler should not remove self from pipeline after failed handshake"); assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), @@ -173,10 +168,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { } @Test - void handleCompleteHandshakePublicKeyMismatch() { + void handleCompleteHandshakePublicKeyMismatch() throws Throwable { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); @@ -184,18 +179,21 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); - assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); + doHandshake( + identityPayload(accountIdentifier, deviceId), + NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY); assertNull(getNoiseHandshakeCompleteEvent()); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class), + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), "Handshake handler should not remove self from pipeline after failed handshake"); } @Test - void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException { + void handleInvalidExtraWrites() + throws NoSuchAlgorithmException, ShortBufferException, InterruptedException { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class)); + assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); @@ -205,25 +203,23 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); - final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer( - initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))); - assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess()); + final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit( + REMOTE_ADDRESS, + HandshakePattern.IK, + Unpooled.wrappedBuffer( + initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)))); + assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess()); // While waiting for the public key, send another message final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await(); - assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow()); + assertInstanceOf(IllegalArgumentException.class, f.exceptionNow()); findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); embeddedChannel.runPendingTasks(); - - // shouldn't return any response or error, we've already processed an error - embeddedChannel.checkException(); - assertNull(embeddedChannel.outboundMessages().poll()); } @Test public void handleOversizeHandshakeMessage() { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1); ByteBuffer.wrap(big) .put(UUIDUtil.toBytes(UUID.randomUUID())) @@ -231,6 +227,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { assertThrows(NoiseHandshakeException.class, () -> doHandshake(big)); } + @Test + public void handleKeyLookupError() throws Throwable { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.failedFuture(new IOException())); + assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); + } + private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException { final HandshakeState clientHandshakeState = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); @@ -262,15 +267,22 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { } private CipherStatePair doHandshake(final byte[] payload) throws Throwable { + return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK); + } + + private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable { final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); final HandshakeState clientHandshakeState = clientHandshakeState(); final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload); - final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage); - final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await(); - assertEquals(0, initiatorMessageFrame.refCnt()); - if (!await.isSuccess()) { + final NoiseHandshakeInit initMessage = new NoiseHandshakeInit( + REMOTE_ADDRESS, + HandshakePattern.IK, + Unpooled.wrappedBuffer(initiatorMessage)); + final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await(); + assertEquals(0, initMessage.refCnt()); + if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) { throw await.cause(); } @@ -280,17 +292,27 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { // and issue a "handshake complete" event. embeddedChannel.runPendingTasks(); - // rethrow if running the task caused an error - embeddedChannel.checkException(); + // rethrow if running the task caused an error, and the caller isn't expecting an error + if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) { + embeddedChannel.checkException(); + } assertFalse(embeddedChannel.outboundMessages().isEmpty()); - final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = - new byte[serverStaticKeyMessageFrame.readableBytes()]; - serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes); + final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); + assertNotNull(handshakeResponseFrame); + final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame); - assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0); + final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder() + .setCode(expectedStatus) + .build(); + + final byte[] actualHandshakeResponsePlaintext = + readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes); + + assertEquals( + expectedHandshakeResponsePlaintext, + NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext)); final byte[] serverPublicKey = new byte[32]; clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); @@ -299,13 +321,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { return clientHandshakeState.split(); } + private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) { + return baseHandshakeInit() + .setAci(UUIDUtil.toByteString(accountIdentifier)) + .setDeviceId(deviceId); + } - private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) { - final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17); - clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits()); - clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits()); - clientIdentityPayloadBuffer.put(deviceId); - clientIdentityPayloadBuffer.flip(); - return clientIdentityPayloadBuffer.array(); + private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) { + return identifiedHandshakeInit(accountIdentifier, deviceId) + .build() + .toByteArray(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java index 4f118b99c..13c8b9924 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java @@ -10,10 +10,10 @@ import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) { public enum CloseReason { + OK, SERVER_CLOSED, NOISE_ERROR, NOISE_HANDSHAKE_ERROR, - AUTHENTICATION_ERROR, INTERNAL_SERVER_ERROR, UNKNOWN } @@ -27,27 +27,27 @@ public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeIniti CloseWebSocketFrame closeWebSocketFrame, CloseInitiator closeInitiator) { final CloseReason code = switch (closeWebSocketFrame.statusCode()) { - case 4003 -> CloseReason.NOISE_ERROR; case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR; - case 4002 -> CloseReason.AUTHENTICATION_ERROR; + case 4002 -> CloseReason.NOISE_ERROR; case 1011 -> CloseReason.INTERNAL_SERVER_ERROR; case 1012 -> CloseReason.SERVER_CLOSED; + case 1000 -> CloseReason.OK; default -> CloseReason.UNKNOWN; }; return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText()); } - public static CloseFrameEvent fromNoiseDirectErrorFrame( - NoiseDirectProtos.Error noiseDirectError, + public static CloseFrameEvent fromNoiseDirectCloseFrame( + NoiseDirectProtos.CloseReason noiseDirectCloseReason, CloseInitiator closeInitiator) { - final CloseReason code = switch (noiseDirectError.getType()) { + final CloseReason code = switch (noiseDirectCloseReason.getCode()) { + case OK -> CloseReason.OK; case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR; case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR; case UNAVAILABLE -> CloseReason.SERVER_CLOSED; case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR; - case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR; case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN; }; - return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage()); + return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java index c6e11c3cc..c735902b6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java @@ -11,12 +11,10 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.ReferenceCountUtil; import java.net.SocketAddress; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import javax.annotation.Nullable; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler; /** @@ -31,12 +29,10 @@ import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler; class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { private final List remoteHandlerStack; - @Nullable - private final AuthenticatedDevice authenticatedDevice; + private final NoiseTunnelProtos.HandshakeInit handshakeInit; private final SocketAddress remoteServerAddress; // If provided, will be sent with the payload in the noise handshake - private final byte[] fastOpenRequest; private final List pendingReads = new ArrayList<>(); @@ -44,13 +40,11 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { EstablishRemoteConnectionHandler( final List remoteHandlerStack, - @Nullable final AuthenticatedDevice authenticatedDevice, final SocketAddress remoteServerAddress, - @Nullable byte[] fastOpenRequest) { + final NoiseTunnelProtos.HandshakeInit handshakeInit) { this.remoteHandlerStack = remoteHandlerStack; - this.authenticatedDevice = authenticatedDevice; + this.handshakeInit = handshakeInit; this.remoteServerAddress = remoteServerAddress; - this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest; } @Override @@ -72,16 +66,19 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { throws Exception { switch (event) { case ReadyForNoiseHandshakeEvent ignored -> - remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload())) + remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray())) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - case NoiseClientHandshakeCompleteEvent(Optional fastResponse) -> { + case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> { remoteContext.pipeline() .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); // If there was a payload response on the handshake, write it back to our gRPC client - fastResponse.ifPresent(plaintext -> - localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext))); + if (!handshakeResponse.getFastOpenResponse().isEmpty()) { + localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse + .getFastOpenResponse() + .asReadOnlyByteBuffer())); + } // Forward any messages we got from our gRPC client, now will be proxied to the remote context pendingReads.forEach(localContext::fireChannelRead); @@ -120,17 +117,4 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { pendingReads.clear(); } - private byte[] initialPayload() { - if (authenticatedDevice == null) { - return fastOpenRequest; - } - - final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length); - bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits()); - bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits()); - bb.put(authenticatedDevice.deviceId()); - bb.put(fastOpenRequest); - bb.flip(); - return bb.array(); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java index 3174651df..66ef45c9f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java @@ -4,6 +4,8 @@ */ package org.whispersystems.textsecuregcm.grpc.net.client; +import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; + import java.util.Optional; /** @@ -12,4 +14,4 @@ import java.util.Optional; * @param fastResponse A response if the client included a request to send in the initiate handshake message payload and * the server included a payload in the handshake response. */ -public record NoiseClientHandshakeCompleteEvent(Optional fastResponse) {} +public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java index b9743bf68..c45d8639c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java @@ -1,5 +1,6 @@ package org.whispersystems.textsecuregcm.grpc.net.client; +import com.google.protobuf.InvalidProtocolBufferException; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; @@ -7,9 +8,10 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; - import java.util.Optional; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; +import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; +import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; public class NoiseClientHandshakeHandler extends ChannelDuplexHandler { @@ -38,9 +40,13 @@ public class NoiseClientHandshakeHandler extends ChannelDuplexHandler { if (message instanceof ByteBuf frame) { try { final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame)); - final Optional fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload); + final NoiseTunnelProtos.HandshakeResponse handshakeResponse = + NoiseTunnelProtos.HandshakeResponse.parseFrom(payload); + context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split())); - context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse)); + context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse)); + } catch (InvalidProtocolBufferException e) { + throw new NoiseHandshakeException("Failed to parse handshake response"); } finally { frame.release(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java index cae4d82d0..b6b6f0a24 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java @@ -8,9 +8,11 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.util.ReferenceCountUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; /** * A Noise transport handler manages a bidirectional Noise session after a handshake has completed. @@ -72,8 +74,10 @@ public class NoiseClientTransportHandler extends ChannelDuplexHandler { ReferenceCountUtil.release(plaintext); } } else { - // Clients only write ByteBufs or close the connection on errors, so any other message is unexpected - log.warn("Unexpected object in pipeline: {}", message); + if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) { + // Clients only write ByteBufs or a close frame on errors, so any other message is unexpected + log.warn("Unexpected object in pipeline: {}", message); + } context.write(message, promise); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java index fddac67fb..e2d7d8d16 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java @@ -1,11 +1,21 @@ package org.whispersystems.textsecuregcm.grpc.net.client; +import com.google.protobuf.ByteString; import com.southernstorm.noise.protocol.Noise; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; -import io.netty.channel.*; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; @@ -17,6 +27,13 @@ import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.util.ReferenceCountUtil; import java.net.SocketAddress; import java.net.URI; import java.security.cert.X509Certificate; @@ -26,26 +43,21 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Function; import java.util.function.Supplier; - -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import io.netty.handler.codec.http.websocketx.WebSocketVersion; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.util.ReferenceCountUtil; +import javax.net.ssl.SSLException; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec; import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec; - -import javax.net.ssl.SSLException; +import org.whispersystems.textsecuregcm.util.UUIDUtil; public class NoiseTunnelClient implements AutoCloseable { private final CompletableFuture closeEventFuture; + private final CompletableFuture handshakeEventFuture; + private final CompletableFuture userCloseFuture; private final ServerBootstrap serverBootstrap; private Channel serverChannel; @@ -66,11 +78,10 @@ public class NoiseTunnelClient implements AutoCloseable { FramingType framingType = FramingType.WEBSOCKET; URI websocketUri = ANONYMOUS_WEBSOCKET_URI; HttpHeaders headers = new DefaultHttpHeaders(); + NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder(); boolean authenticated = false; ECKeyPair ecKeyPair = null; - UUID accountIdentifier = null; - byte deviceId = 0x00; boolean useTls; X509Certificate trustedServerCertificate = null; Supplier proxyMessageSupplier = null; @@ -86,8 +97,8 @@ public class NoiseTunnelClient implements AutoCloseable { public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) { this.authenticated = true; - this.accountIdentifier = accountIdentifier; - this.deviceId = deviceId; + handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier)); + handshakeInit.setDeviceId(deviceId); this.ecKeyPair = ecKeyPair; this.websocketUri = AUTHENTICATED_WEBSOCKET_URI; return this; @@ -109,6 +120,16 @@ public class NoiseTunnelClient implements AutoCloseable { return this; } + public Builder setUserAgent(final String userAgent) { + handshakeInit.setUserAgent(userAgent); + return this; + } + + public Builder setAcceptLanguage(final String acceptLanguage) { + handshakeInit.setAcceptLanguage(acceptLanguage); + return this; + } + public Builder setHeaders(final HttpHeaders headers) { this.headers = headers; return this; @@ -155,17 +176,41 @@ public class NoiseTunnelClient implements AutoCloseable { handlers.add(new NoiseClientHandshakeHandler(helper)); + // When the noise handshake completes we'll save the response from the server so client users can inspect it + final UserEventFuture handshakeEventHandler = + new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class); + handlers.add(handshakeEventHandler); + // Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off // information about why the connection was closed. final UserEventFuture closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class); handlers.add(closeEventHandler); + // When the user closes the client, write a normal closure close frame + final CompletableFuture userCloseFuture = new CompletableFuture<>(); + handlers.add(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(final ChannelHandlerContext ctx) { + userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) { + case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE); + case NOISE_DIRECT -> new NoiseDirectFrame( + NoiseDirectFrame.FrameType.CLOSE, + Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason + .newBuilder() + .setCode(NoiseDirectProtos.CloseReason.Code.OK) + .build() + .toByteArray())); + }) + .addListener(ChannelFutureListener.CLOSE), + ctx.executor()); + } + }); + final NoiseTunnelClient client = - new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler( + new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler( handlers, - authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null, remoteServerAddress, - fastOpenRequest)); + handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build())); client.start(); return client; } @@ -173,9 +218,13 @@ public class NoiseTunnelClient implements AutoCloseable { private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup, CompletableFuture closeEventFuture, + CompletableFuture handshakeEventFuture, + CompletableFuture userCloseFuture, Function handler) { + this.userCloseFuture = userCloseFuture; this.closeEventFuture = closeEventFuture; + this.handshakeEventFuture = handshakeEventFuture; this.serverBootstrap = new ServerBootstrap() .localAddress(new LocalAddress("websocket-noise-tunnel-client")) .channel(LocalServerChannel.class) @@ -194,10 +243,10 @@ public class NoiseTunnelClient implements AutoCloseable { .addLast(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { - if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) { - byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest()); - requestBufferedEvent.fastOpenRequest().release(); - ctx.pipeline().addLast(handler.apply(fastOpenRequest)); + if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) { + byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest); + fastOpenRequest.release(); + ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes)); } super.userEventTriggered(ctx, evt); } @@ -216,7 +265,7 @@ public class NoiseTunnelClient implements AutoCloseable { } @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { if (cls.isInstance(evt)) { future.complete((T) evt); } @@ -236,6 +285,7 @@ public class NoiseTunnelClient implements AutoCloseable { @Override public void close() throws InterruptedException { + userCloseFuture.complete(null); serverChannel.close().await(); } @@ -246,6 +296,14 @@ public class NoiseTunnelClient implements AutoCloseable { return closeEventFuture; } + /** + * @return A future that completes when the noise handshake finishes + */ + public CompletableFuture getHandshakeEventFuture() { + return handshakeEventFuture; + } + + private static List noiseDirectHandlerStack(boolean authenticated) { return List.of( new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2), @@ -259,12 +317,12 @@ public class NoiseTunnelClient implements AutoCloseable { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) { + if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { try { - final NoiseDirectProtos.Error errorPayload = - NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content())); + final NoiseDirectProtos.CloseReason closeReason = + NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content())); ctx.fireUserEventTriggered( - CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER)); + CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER)); } finally { ReferenceCountUtil.release(msg); } @@ -275,11 +333,11 @@ public class NoiseTunnelClient implements AutoCloseable { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) { - final NoiseDirectProtos.Error errorPayload = - NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content())); + if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { + final NoiseDirectProtos.CloseReason errorPayload = + NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content())); ctx.fireUserEventTriggered( - CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT)); + CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT)); } ctx.write(msg, promise); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java index 04a68714e..a38c9181e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java @@ -126,11 +126,13 @@ class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelSe final HttpHeaders headers = new DefaultHttpHeaders() .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add("X-Forwarded-For", remoteAddress) - .add("Accept-Language", acceptLanguage) - .add("User-Agent", userAgent); + .add("X-Forwarded-For", remoteAddress); - try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) { + try (final NoiseTunnelClient client = anonymous() + .setHeaders(headers) + .setUserAgent(userAgent) + .setAcceptLanguage(acceptLanguage) + .build()) { final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java index 30ef53022..2efb0e72c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java @@ -3,12 +3,10 @@ package org.whispersystems.textsecuregcm.grpc.net.websocket; import io.netty.channel.local.LocalAddress; import io.netty.channel.nio.NioEventLoopGroup; import java.util.concurrent.Executor; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java index 97aeef600..e27311489 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java @@ -6,9 +6,9 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.params.provider.Arguments.argumentSet; import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.mockito.Mockito.mock; import com.google.common.net.InetAddresses; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -17,7 +17,6 @@ import io.netty.channel.local.LocalAddress; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import io.netty.util.Attribute; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -32,13 +31,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.signal.libsignal.protocol.ecc.Curve; -import org.whispersystems.textsecuregcm.grpc.RequestAttributes; import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler; -import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; +import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { @@ -84,9 +80,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { userEventRecordingHandler = new UserEventRecordingHandler(); embeddedChannel = new MutableRemoteAddressEmbeddedChannel( - new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class), - Curve.generateKeyPair(), - RECOGNIZED_PROXY_SECRET), + new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET), userEventRecordingHandler); embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0)); @@ -94,22 +88,25 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { @ParameterizedTest @MethodSource - void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { + void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) { final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); - assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass)); - assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents()); + + final byte[] payload = TestRandomUtil.nextBytes(100); + embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload)); + assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); + final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll(); + assertNotNull(init); + assertEquals(init.getHandshakePattern(), pattern); } private static List handleWebSocketHandshakeComplete() { return List.of( - Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class), - Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class)); + Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK), + Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK)); } @Test @@ -141,13 +138,19 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { embeddedChannel.setRemoteAddress(remoteAddress); embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - - assertEquals(expectedRemoteAddress, - Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY)) - .map(Attribute::get) - .map(RequestAttributes::remoteAddress) + final byte[] payload = TestRandomUtil.nextBytes(100); + embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload)); + final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll(); + assertEquals( + expectedRemoteAddress, + Optional.ofNullable(init) + .map(NoiseHandshakeInit::getRemoteAddress) .orElse(null)); + if (expectedRemoteAddress == null) { + assertThrows(IllegalStateException.class, embeddedChannel::checkException); + } else { + assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); + } } private static List getRemoteAddress() {