Replace XX/NX handshakes with IK/NK

This commit is contained in:
Ravi Khadiwala 2024-06-24 17:20:42 -05:00 committed by ravi-signal
parent c835d85256
commit 542422b7b8
41 changed files with 1902 additions and 1611 deletions

View File

@ -478,7 +478,6 @@ noiseTunnel:
tlsKeyStoreEntryAlias: example.com
tlsKeyStorePassword: secret://noiseTunnel.tlsKeyStorePassword
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
externalRequestFilter:

View File

@ -933,7 +933,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
clientConnectionManager,
clientPublicKeysManager,
config.getNoiseWebSocketTunnelConfiguration().noiseStaticKeyPair(),
config.getNoiseWebSocketTunnelConfiguration().noiseRootPublicKeySignature(),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
config.getNoiseWebSocketTunnelConfiguration().recognizedProxySecret().value());

View File

@ -15,7 +15,6 @@ public record NoiseWebSocketTunnelConfiguration(@Positive int port,
@Nullable String tlsKeyStoreEntryAlias,
@Nullable SecretString tlsKeyStorePassword,
@NotNull SecretBytes noiseStaticPrivateKey,
@NotNull byte[] noiseRootPublicKeySignature,
@NotNull SecretString recognizedProxySecret) {
public ECKeyPair noiseStaticKeyPair() throws InvalidKeyException {

View File

@ -1,124 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.internal.EmptyArrays;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* An abstract base class for XX- and NX-patterned Noise responder handshake handlers.
*
* @see <a href="https://noiseprotocol.org/noise.html">The Noise Protocol Framework</a>
*/
abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
private final HandshakeState handshakeState;
private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32;
/**
* Constructs a new Noise handler with the given static server keys and static public key signature. The static public
* key must be signed by a trusted root private key whose public key is known to and trusted by authenticating
* clients.
*
* @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler
* @param ecKeyPair the static key pair for this server
* @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key
*/
AbstractNoiseHandshakeHandler(final String noiseProtocolName,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature) {
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
try {
this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e);
}
}
protected HandshakeState getHandshakeState() {
return handshakeState;
}
/**
* Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's
* static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server.
* Clients must not include an additional payload with their ephemeral key message. The server's reply contains its
* static keys along with an Ed25519 signature of its public static key by a trusted root key.
*
* @param context the channel handler context for this message
* @param frame the websocket frame containing the ephemeral key message
*
* @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a
* general Noise encryption error occurred
*/
protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
throws NoiseHandshakeException {
if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) {
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
}
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
// key), so we just set the private key here.
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
handshakeState.start();
// The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is
// coming off the wire and so will be in a direct buffer that doesn't have a backing array.
final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content());
frame.content().readBytes(ephemeralKeyMessage);
try {
handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
} catch (final ShortBufferException e) {
// This should never happen since we're checking the length of the frame up front
throw new NoiseHandshakeException("Unexpected client payload");
} catch (final BadPaddingException e) {
// It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag"
// subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an
// ephemeral-key-only message
throw new NoiseHandshakeException("Invalid keys");
}
// Send our key material and public key signature back to the client; this buffer will include:
//
// - A 32-byte plaintext ephemeral key
// - A 32-byte encrypted static key
// - A 16-byte AEAD tag for the static key
// - The public key signature payload
// - A 16-byte AEAD tag for the payload
final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16];
try {
handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
} catch (final ShortBufferException e) {
// This should never happen for messages of known length that we control
throw new AssertionError("Key material buffer was too short for message", e);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
handshakeState.destroy();
}
}

View File

@ -9,6 +9,7 @@ import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import javax.crypto.BadPaddingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
@ -38,7 +39,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter {
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
if (websocketHandshakeComplete) {
final WebSocketCloseStatus webSocketCloseStatus = switch (cause) {
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
@ -51,6 +52,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter {
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error occurred before websocket handshake complete", cause);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
context.close();

View File

@ -48,12 +48,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service.
final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent()
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress;
@ -72,7 +72,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
if (localChannelFuture.isSuccess()) {
clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
noiseHandshakeCompleteEvent.authenticatedDevice());
noiseIdentityDeterminedEvent.authenticatedDevice());
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());

View File

@ -0,0 +1,21 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
enum HandshakePattern {
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
private final String protocol;
public String protocol() {
return protocol;
}
HandshakePattern(String protocol) {
this.protocol = protocol;
}
}

View File

@ -0,0 +1,34 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
* and noise encryption/decryption.
* <p>
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
* initial request in their payload. If provided, this allows the server to begin processing the request without an
* initial message delay (fast open).
* <p>
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
* indicating that initiator connected anonymously.
*/
class NoiseAnonymousHandler extends NoiseHandler {
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
}
@Override
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
return CompletableFuture.completedFuture(new HandshakeResult(
handshakePayload,
Optional.empty()
));
}
}

View File

@ -0,0 +1,96 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.ReferenceCountUtil;
import java.security.MessageDigest;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
* key is authenticated by the responder.
* <p>
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
* initial request in their payload. If provided, this allows the server to begin processing the request without an
* initial message delay (fast open).
* <pre>
* +-----------------+----------------+------------------------+
* | UUID (16) | deviceId (1) | request bytes (N) |
* +-----------------+----------------+------------------------+
* </pre>
* <p>
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
* key for the device identified by the provided ACI and deviceId.
* <p>
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
*/
class NoiseAuthenticatedHandler extends NoiseHandler {
private final ClientPublicKeysManager clientPublicKeysManager;
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
this.clientPublicKeysManager = clientPublicKeysManager;
}
@Override
CompletableFuture<HandshakeResult> handleHandshakePayload(
final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey,
final ByteBuf handshakePayload) throws NoiseHandshakeException {
if (handshakePayload.readableBytes() < 17) {
throw new NoiseHandshakeException("Invalid handshake payload");
}
final byte[] publicKeyFromClient = initiatorPublicKey
.orElseThrow(() -> new IllegalStateException("No remote public key"));
// Advances the read index by 16 bytes
final UUID accountIdentifier = parseUUID(handshakePayload);
// Advances the read index by 1 byte
final byte deviceId = handshakePayload.readByte();
final ByteBuf fastOpenRequest = handshakePayload.slice();
return clientPublicKeysManager
.findPublicKey(accountIdentifier, deviceId)
.handleAsync((storedPublicKey, throwable) -> {
if (throwable != null) {
ReferenceCountUtil.release(fastOpenRequest);
throw ExceptionUtils.wrap(throwable);
}
final boolean valid = storedPublicKey
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
.orElse(false);
if (!valid) {
throw ExceptionUtils.wrap(new ClientAuthenticationException());
}
return new HandshakeResult(
fastOpenRequest,
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
}, context.executor());
}
/**
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
*
* @param bytes The {@link ByteBuf} to read from
* @return The parsed UUID
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
*/
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
if (bytes.readableBytes() < 16) {
throw new NoiseHandshakeException("Could not parse account identifier");
}
return new UUID(bytes.readLong(), bytes.readLong());
}
}

View File

@ -0,0 +1,195 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.CipherStatePair;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
* inbound messages, and encrypts outbound messages
*/
abstract class NoiseHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
private enum State {
// Waiting for handshake to complete
HANDSHAKE,
// Can freely exchange encrypted noise messages on an established session
TRANSPORT,
// Finished with error
ERROR
}
private final NoiseHandshakeHelper handshakeHelper;
private State state = State.HANDSHAKE;
private CipherStatePair cipherStatePair;
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
this.handshakeHelper = handshakeHelper;
}
/**
* The result of processing an initiator handshake payload
*
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
* empty ByteBuf
* @param authenticatedDevice If present, the successfully authenticated initiator identity
*/
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
/**
* Parse and potentially authenticate the initiator handshake message
*
* @param context A {@link ChannelHandlerContext}
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
* @param handshakePayload The handshake payload provided in the initiator message
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
* present in the handshake payload.
* @throws NoiseHandshakeException If the handshake payload was invalid
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
*/
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
final ChannelHandlerContext context,
final Optional<byte[]> initiatorPublicKey,
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} catch (Exception e) {
fail(context, e);
} finally {
ReferenceCountUtil.release(message);
}
}
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
switch (state) {
// Got an initiator handshake message
case HANDSHAKE -> {
final ByteBuf payload = handshakeHelper.read(frameBytes);
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
(result, throwable) -> {
if (state == State.ERROR) {
return;
}
if (throwable != null) {
fail(context, ExceptionUtils.unwrap(throwable));
return;
}
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
// Now that we've authenticated, write the handshake response
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
this.state = State.TRANSPORT;
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
if (result.fastOpenRequest().isReadable()) {
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
// encrypt the response when the server writes back through us
context.fireChannelRead(result.fastOpenRequest());
} else {
ReferenceCountUtil.release(result.fastOpenRequest());
}
}, context.executor());
}
// Got a client message that should be decrypted and forwarded
case TRANSPORT -> {
final CipherState cipherState = cipherStatePair.getReceiver();
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
final int plaintextLength = cipherState.decryptWithAd(null,
frameBytes, 0,
frameBytes, 0,
frameBytes.length);
// Forward the decrypted plaintext along
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
}
// The session is already in an error state, drop the message
case ERROR -> {
}
}
}
/**
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
* context
*/
private void fail(final ChannelHandlerContext context, final Throwable cause) {
this.state = State.ERROR;
context.fireExceptionCaught(cause);
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf plaintext) {
try {
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
final CipherState cipherState = cipherStatePair.getSender();
final int plaintextLength = plaintext.readableBytes();
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
} finally {
ReferenceCountUtil.release(plaintext);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
}

View File

@ -1,13 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import java.util.Optional;
/**
* An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process.
*
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
* type that performs authentication
*/
record NoiseHandshakeCompleteEvent(Optional<AuthenticatedDevice> authenticatedDevice) {
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* Helper for the responder of a 2-message handshake with a pre-shared responder static key
*/
class NoiseHandshakeHelper {
private final static int AEAD_TAG_LENGTH = 16;
private final static int KEY_LENGTH = 32;
private final HandshakePattern handshakePattern;
private final ECKeyPair serverStaticKeyPair;
private final HandshakeState handshakeState;
NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) {
this.handshakePattern = handshakePattern;
this.serverStaticKeyPair = serverStaticKeyPair;
try {
this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e);
}
}
/**
* Get the length of the initiator's keys
*
* @return length of the handshake message sent by the remote party (the initiator) not including the payload
*/
private int initiatorHandshakeMessageKeyLength() {
return switch (handshakePattern) {
// ephemeral key, static key (encrypted), AEAD tag for static key
case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH;
// ephemeral key only
case NK -> KEY_LENGTH;
};
}
HandshakeState getHandshakeState() {
return this.handshakeState;
}
ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException {
if (handshakeState.getAction() != HandshakeState.NO_ACTION) {
throw new NoiseHandshakeException("Cannot send more data before handshake is complete");
}
// Length for an empty payload
final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH;
if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) {
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
}
final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH;
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
// key), so we just set the private key here.
handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0);
handshakeState.start();
int payloadBytesRead;
try {
payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length,
remoteHandshakeMessage, 0);
} catch (final ShortBufferException e) {
// This should never happen since we're checking the length of the frame up front
throw new NoiseHandshakeException("Unexpected client payload");
} catch (final BadPaddingException e) {
// We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key
// or payload
throw new NoiseHandshakeException("Invalid keys or payload");
}
if (payloadBytesRead != payloadLength) {
throw new NoiseHandshakeException(
"Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead);
}
return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead);
}
byte[] write(byte[] payload) {
if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) {
throw new IllegalStateException("Cannot send data before handshake is complete");
}
// Currently only support handshake patterns where the server static key is known
// Send our ephemeral key and the response to the initiator with the encrypted payload
final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH];
try {
int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length);
if (written != response.length) {
throw new IllegalStateException("Unexpected handshake response length");
}
return response;
} catch (final ShortBufferException e) {
// This should never happen for messages of known length that we control
throw new IllegalStateException("Key material buffer was too short for message", e);
}
}
Optional<byte[]> remotePublicKey() {
return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> {
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
return publicKeyFromClient;
});
}
}

View File

@ -0,0 +1,13 @@
package org.whispersystems.textsecuregcm.grpc.net;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/**
* An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is
* connecting anonymously, the identity is empty, otherwise it will be present and already authenticated.
*
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
* type that performs authentication
*/
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}

View File

@ -1,40 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.util.Optional;
import io.netty.util.ReferenceCountUtil;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* A Noise NX handler handles the responder side of a Noise NX handshake.
*/
class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler {
static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b";
NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) {
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof BinaryWebSocketFrame frame) {
try {
handleEphemeralKeyMessage(context, frame);
} finally {
frame.release();
}
// All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider
// the handshake complete
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseTransportHandler(getHandshakeState().split()));
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
ReferenceCountUtil.release(message);
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
}
}

View File

@ -53,7 +53,6 @@ public class NoiseWebSocketTunnelServer implements Managed {
final ClientConnectionManager clientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws SSLException {
@ -107,7 +106,7 @@ public class NoiseWebSocketTunnelServer implements Managed {
.addLast(new RejectUnsupportedMessagesHandler())
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
// a WebSocket handshake has been completed
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret))
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))

View File

@ -1,178 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send
* identifying information (an account identifier and device ID) as an additional payload when sending its static key
* material. It compares the static public key against the stored public key for the identified device asynchronously,
* buffering traffic from the client until the authentication check completes.
*/
class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler {
private final ClientPublicKeysManager clientPublicKeysManager;
private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY;
private final List<BinaryWebSocketFrame> pendingInboundFrames = new ArrayList<>();
static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b";
// When the client sends its static key message, we expect:
//
// - A 32-byte encrypted static public key
// - A 16-byte AEAD tag for the static key
// - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID)
// - A 16-byte AEAD tag for the identity payload
private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81;
private enum AuthenticationState {
GET_EPHEMERAL_KEY,
GET_STATIC_KEY,
CHECK_PUBLIC_KEY,
ERROR
}
public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature) {
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
this.clientPublicKeysManager = clientPublicKeysManager;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof BinaryWebSocketFrame frame) {
try {
switch (authenticationState) {
case GET_EPHEMERAL_KEY -> {
try {
handleEphemeralKeyMessage(context, frame);
authenticationState = AuthenticationState.GET_STATIC_KEY;
} finally {
frame.release();
}
}
case GET_STATIC_KEY -> {
try {
handleStaticKey(context, frame);
authenticationState = AuthenticationState.CHECK_PUBLIC_KEY;
} finally {
frame.release();
}
}
case CHECK_PUBLIC_KEY -> {
// Buffer any inbound traffic until we've finished checking the client's public key
pendingInboundFrames.add(frame);
}
case ERROR -> {
// If authentication has failed for any reason, just discard inbound traffic until the channel closes
frame.release();
}
}
} catch (final ShortBufferException e) {
authenticationState = AuthenticationState.ERROR;
throw new NoiseHandshakeException("Unexpected payload length");
} catch (final BadPaddingException e) {
authenticationState = AuthenticationState.ERROR;
throw new ClientAuthenticationException();
}
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
ReferenceCountUtil.release(message);
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
}
private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
throws NoiseHandshakeException, ShortBufferException, BadPaddingException {
if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) {
throw new NoiseHandshakeException("Unexpected client static key message length");
}
final HandshakeState handshakeState = getHandshakeState();
// The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct
// buffer into a heap buffer.
final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content());
// The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte)
final byte[] payload = new byte[17];
final UUID accountIdentifier;
final byte deviceId;
final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage,
0, staticKeyAndClientIdentityMessage.length, payload, 0);
if (payloadBytesRead != 17) {
throw new NoiseHandshakeException("Unexpected identity payload length");
}
try {
accountIdentifier = UUIDUtil.fromBytes(payload, 0);
} catch (final IllegalArgumentException e) {
throw new NoiseHandshakeException("Could not parse account identifier");
}
deviceId = payload[16];
// Verify the identity of the caller by comparing the submitted static public key against the stored public key for
// the identified device
clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)
.whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> {
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) {
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))));
context.pipeline().addAfter(context.name(), null, new NoiseTransportHandler(handshakeState.split()));
// Flush any buffered reads
pendingInboundFrames.forEach(context::fireChannelRead);
pendingInboundFrames.clear();
context.pipeline().remove(NoiseXXHandshakeHandler.this);
} else {
// We found a key, but it doesn't match what the caller submitted
context.fireExceptionCaught(new ClientAuthenticationException());
authenticationState = AuthenticationState.ERROR;
}
},
() -> {
// We couldn't find a key for the identified account/device
context.fireExceptionCaught(new ClientAuthenticationException());
authenticationState = AuthenticationState.ERROR;
}),
context.executor());
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
super.handlerRemoved(context);
pendingInboundFrames.forEach(BinaryWebSocketFrame::release);
pendingInboundFrames.clear();
}
}

View File

@ -31,7 +31,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
private final byte[] recognizedProxySecret;
@ -45,12 +44,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature,
final String recognizedProxySecret) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
@ -84,10 +81,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
new NoiseAnonymousHandler(ecKeyPair);
default -> {
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an

View File

@ -1,94 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter {
private final ECPublicKey rootPublicKey;
private final HandshakeState handshakeState;
AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) {
this.rootPublicKey = rootPublicKey;
try {
handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e);
}
}
protected abstract String getNoiseProtocolName();
protected abstract void startHandshake();
protected HandshakeState getHandshakeState() {
return handshakeState;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
startHandshake();
final byte[] ephemeralKeyMessage = new byte[32];
handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
}
}
super.userEventTriggered(context, event);
}
protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
throws NoiseHandshakeException {
// The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap
// buffer so we can Noise at it.
final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes());
final byte[] serverPublicKeySignature = new byte[64];
try {
frame.content().readBytes(keyMaterialBuffer);
final int payloadBytesRead =
handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0);
if (payloadBytesRead != 64) {
throw new NoiseHandshakeException("Unexpected signature length");
}
} catch (final ShortBufferException e) {
throw new NoiseHandshakeException("Unexpected signature length");
} catch (final BadPaddingException e) {
throw new NoiseHandshakeException("Invalid keys");
} finally {
keyMaterialBuffer.release();
}
final byte[] serverPublicKey = new byte[32];
handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) {
throw new NoiseHandshakeException("Invalid server public key signature");
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) throws Exception {
handshakeState.destroy();
}
}

View File

@ -0,0 +1,257 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.southernstorm.noise.protocol.CipherStatePair;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
protected ECKeyPair serverKeyPair;
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
private EmbeddedChannel embeddedChannel;
private static class PongHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
try {
if (msg instanceof ByteBuf bb) {
if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) {
ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes()))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
} else {
throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb)));
}
} else {
throw new IllegalArgumentException("Unexpected message type: " + msg);
}
} finally {
ReferenceCountUtil.release(msg);
}
}
}
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@Nullable
private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null;
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
handshakeCompleteEvent = noiseIdentityDeterminedEvent;
context.pipeline().addAfter(context.name(), null, new PongHandler());
context.pipeline().remove(NoiseHandshakeCompleteHandler.class);
} else {
context.fireUserEventTriggered(event);
}
}
@Nullable
public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() {
return handshakeCompleteEvent;
}
}
@BeforeEach
void setUp() {
serverKeyPair = Curve.generateKeyPair();
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler);
}
@AfterEach
void tearDown() {
embeddedChannel.close();
}
protected EmbeddedChannel getEmbeddedChannel() {
return embeddedChannel;
}
@Nullable
protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() {
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
}
protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair);
protected abstract CipherStatePair doHandshake() throws Throwable;
/**
* Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no
* waiting messages in the channel, return null.
*/
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
if (responseFrame == null) {
return null;
}
final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16];
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
ByteBufUtil.getBytes(responseFrame.content()), 0,
plaintext, 0,
responseFrame.content().readableBytes());
assertEquals(read, plaintext.length);
return plaintext;
}
@Test
void handleInvalidInitialMessage() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
assertEquals(0, content.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
for (int i = 0; i < frames.length; i++) {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
embeddedChannel.writeOneInbound(frames[i]).await();
}
for (final BinaryWebSocketFrame frame : frames) {
assertEquals(0, frame.refCnt());
}
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleNonWebSocketBinaryFrame() throws Throwable {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
assertEquals(0, message.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelRead() throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8);
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
assertEquals(0, ciphertextFrame.refCnt());
final byte[] response = readNextPlaintext(clientCipherStatePair);
assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response);
}
@Test
void channelReadBadCiphertext() throws Throwable {
doHandshake();
final byte[] bogusCiphertext = new byte[32];
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
assertEquals(0, ciphertextFrame.refCnt());
assertFalse(readCiphertextFuture.isSuccess());
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelReadUnexpectedMessageType() throws Throwable {
doHandshake();
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
assertFalse(readFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void write() throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.await().isSuccess());
assertEquals(0, plaintextBuffer.refCnt());
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
assertNotNull(ciphertextFrame);
assertTrue(embeddedChannel.outboundMessages().isEmpty());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
ciphertextFrame.release();
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void writeUnexpectedMessageType() throws Throwable {
doHandshake();
final Object unexpectedMessaged = new Object();
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
assertTrue(writeFuture.await().isSuccess());
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
assertTrue(embeddedChannel.outboundMessages().isEmpty());
}
}

View File

@ -1,141 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest {
private ECPublicKey rootPublicKey;
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
private EmbeddedChannel embeddedChannel;
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@Nullable
private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null;
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
handshakeCompleteEvent = noiseHandshakeCompleteEvent;
} else {
context.fireUserEventTriggered(event);
}
}
@Nullable
public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() {
return handshakeCompleteEvent;
}
}
@BeforeEach
void setUp() {
final ECKeyPair rootKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
rootPublicKey = rootKeyPair.getPublicKey();
final byte[] serverPublicKeySignature =
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes());
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
embeddedChannel =
new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler);
}
@AfterEach
void tearDown() {
embeddedChannel.close();
}
protected EmbeddedChannel getEmbeddedChannel() {
return embeddedChannel;
}
protected ECPublicKey getRootPublicKey() {
return rootPublicKey;
}
@Nullable
protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() {
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
}
protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature);
@Test
void handleInvalidInitialMessage() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
assertEquals(0, content.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
for (int i = 0; i < frames.length; i++) {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
embeddedChannel.writeOneInbound(frames[i]).await();
}
for (final BinaryWebSocketFrame frame : frames) {
assertEquals(0, frame.refCnt());
}
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleNonWebSocketBinaryFrame() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
assertEquals(0, message.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
}

View File

@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.Noise;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@ -19,6 +20,7 @@ import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
@ -29,6 +31,10 @@ import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
/**
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
* gRPC server
*/
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final boolean useTls;
@ -36,13 +42,15 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final URI websocketUri;
private final boolean authenticated;
@Nullable private final ECKeyPair ecKeyPair;
private final ECPublicKey rootPublicKey;
private final ECPublicKey serverPublicKey;
@Nullable private final UUID accountIdentifier;
private final byte deviceId;
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
// If provided, will be sent with the payload in the noise handshake
private final byte[] fastOpenRequest;
private final List<Object> pendingReads = new ArrayList<>();
@ -54,26 +62,28 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final URI websocketUri,
final boolean authenticated,
@Nullable final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
final ECPublicKey serverPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener,
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier) {
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier,
@Nullable byte[] fastOpenRequest) {
this.useTls = useTls;
this.trustedServerCertificate = trustedServerCertificate;
this.websocketUri = websocketUri;
this.authenticated = authenticated;
this.ecKeyPair = ecKeyPair;
this.rootPublicKey = rootPublicKey;
this.serverPublicKey = serverPublicKey;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.headers = headers;
this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener;
this.proxyMessageSupplier = proxyMessageSupplier;
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
}
@Override
@ -104,6 +114,10 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
}
final NoiseClientHandshakeHelper helper = authenticated
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
: NoiseClientHandshakeHelper.NK(serverPublicKey);
channel.pipeline()
.addLast(new HttpClientCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
@ -118,22 +132,24 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
Noise.MAX_PACKET_LEN,
10_000))
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
.addLast(authenticated
? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId)
: new NoiseNXClientHandshakeHandler(rootPublicKey))
// Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done
.addLast(new NoiseClientHandshakeHandler(helper, initialPayload()))
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
throws Exception {
if (event instanceof NoiseHandshakeCompleteEvent) {
if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) {
remoteContext.pipeline()
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
// If there was a payload response on the handshake, write it back to our gRPC client
handshakeCompleteEvent.fastResponse().ifPresent(plaintext ->
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
pendingReads.forEach(localContext::fireChannelRead);
pendingReads.clear();
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
}
@ -165,4 +181,18 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
private byte[] initialPayload() {
if (!authenticated) {
return fastOpenRequest;
}
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
bb.putLong(accountIdentifier.getMostSignificantBits());
bb.putLong(accountIdentifier.getLeastSignificantBits());
bb.put(deviceId);
bb.put(fastOpenRequest);
bb.flip();
return bb.array();
}
}

View File

@ -0,0 +1,9 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}

View File

@ -0,0 +1,184 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import java.util.stream.Stream;
/**
* The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast
* open" optimization where the client can send a request along with the noise handshake. There's no direct way to
* extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an
* inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request.
* <p>
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
* {@link FastOpenRequestBufferedEvent}
*/
class Http2Buffering {
/**
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
*/
static ChannelInboundHandler handler() {
return new Http2PrefaceHandler();
}
private Http2Buffering() {
}
private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter {
// https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5
private static final byte[] HTTP2_PREFACE =
HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a");
private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length);
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
if (message instanceof ByteBuf bb) {
bb.readBytes(read);
if (read.readableBytes() < HTTP2_PREFACE.length) {
// Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more.
return;
}
if (!Arrays.equals(read.array(), HTTP2_PREFACE)) {
throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface");
}
context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder());
context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder());
context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler());
context.fireChannelRead(bb);
} else {
throw new IllegalStateException("Unexpected message: " + message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
ReferenceCountUtil.release(read);
}
}
private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) {
private static final byte FLAG_END_STREAM = 0x01;
enum FrameType {
SETTINGS,
HEADERS,
DATA,
WINDOW_UPDATE,
OTHER;
static FrameType fromSerializedType(final byte type) {
return switch (type) {
case 0x00 -> Http2Frame.FrameType.DATA;
case 0x01 -> Http2Frame.FrameType.HEADERS;
case 0x04 -> Http2Frame.FrameType.SETTINGS;
case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE;
default -> Http2Frame.FrameType.OTHER;
};
}
}
}
/**
* Emit ByteBuf of entire HTTP/2 frame
*/
private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder {
public Http2LengthFieldFrameDecoder() {
// Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload
super(16 * 1024 * 1024, 0, 3, 6, 0);
}
}
/**
* Parse the serialized Http/2 frames into {@link Http2Frame} objects
*/
private static class Http2FrameDecoder extends ByteToMessageDecoder {
@Override
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) throws Exception {
// https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1
final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3));
final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4));
out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream));
}
boolean endStream(Http2Frame.FrameType frameType, byte flags) {
// A gRPC request are packed into HTTP/2 frames like:
// HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1)
//
// Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set
// we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about
// is having at least one complete request. In total, we can expect something like:
// HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1)
//
// The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack
// bit set, but HTTP/2 allows the client to send frames before getting the ACK.
if (frameType == Http2Frame.FrameType.DATA) {
return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM;
}
// In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete
// (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a
// DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we
// might as well pack that in.
return frameType == Http2Frame.FrameType.WINDOW_UPDATE;
}
}
/**
* Collect HTTP/2 frames until we get an entire "request" to send
*/
private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter {
final List<Http2Frame> pendingFrames = new ArrayList<>();
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
if (message instanceof Http2Frame http2Frame) {
if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) {
throw new IllegalStateException(
"HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type);
}
pendingFrames.add(http2Frame);
if (http2Frame.endStream) {
// We have a whole "request", emit the first request event and remove the http2 buffering handlers
final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat(
Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)),
pendingFrames.stream().map(Http2Frame::bytes))
.toArray(ByteBuf[]::new));
pendingFrames.clear();
context.pipeline().remove(Http2LengthFieldFrameDecoder.class);
context.pipeline().remove(Http2FrameDecoder.class);
context.pipeline().remove(this);
context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request));
}
} else {
throw new IllegalStateException("Unexpected message: " + message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes()));
pendingFrames.clear();
}
}
}

View File

@ -0,0 +1,108 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
@Override
protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) {
return new NoiseAnonymousHandler(serverKeyPair);
}
@Override
protected CipherStatePair doHandshake() throws Exception {
return doHandshake(new byte[0]);
}
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState =
new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
clientHandshakeState.start();
// Send initiator handshake message
// 32 byte key, request payload, 16 byte AEAD tag
final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16;
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength];
assertEquals(
initiateHandshakeMessageLength,
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame(
Unpooled.wrappedBuffer(initiateHandshakeMessage));
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess());
assertEquals(0, initiateHandshakeFrame.refCnt());
embeddedChannel.runPendingTasks();
// Read responder handshake message
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame)
embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
new byte[responderHandshakeFrame.content().readableBytes()];
responderHandshakeFrame.content().readBytes(responderHandshakeBytes);
// ephemeral key, empty encrypted payload AEAD tag
final byte[] handshakeResponsePayload = new byte[32 + 16];
assertEquals(0,
clientHandshakeState.readMessage(
responderHandshakeBytes, 0, responderHandshakeBytes.length,
handshakeResponsePayload, 0));
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
return clientHandshakeState.split();
}
@Test
void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes()));
final byte[] response = readNextPlaintext(cipherStatePair);
assertArrayEquals(response, "pong".getBytes());
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0]));
assertNull(readNextPlaintext(cipherStatePair));
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
}
}

View File

@ -0,0 +1,301 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.internal.EmptyArrays;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
private ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair clientKeyPair = Curve.generateKeyPair();
@Override
@BeforeEach
void setUp() {
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
super.setUp();
}
@Override
protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) {
return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair);
}
@Override
protected CipherStatePair doHandshake() throws Throwable {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
return doHandshake(identityPayload(accountIdentifier, deviceId));
}
@Test
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final ByteBuffer bb = ByteBuffer.allocate(17 + 4);
bb.put(identityPayload(accountIdentifier, deviceId));
bb.put("ping".getBytes());
final byte[] response = readNextPlaintext(doHandshake(bb.array()));
assertEquals(response.length, 4);
assertEquals(new String(response), "pong");
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeMissingIdentityInformation() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
verifyNoInteractions(clientPublicKeysManager);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeMalformedIdentityInformation() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
// no deviceId byte
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload));
verifyNoInteractions(clientPublicKeysManager);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeUnrecognizedDevice() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakePublicKeyMismatch() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
}
@Test
void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final HandshakeState clientHandshakeState = clientHandshakeState();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
// While waiting for the public key, send another message
final ChannelFuture f = embeddedChannel.writeOneInbound(
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await();
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
embeddedChannel.runPendingTasks();
// shouldn't return any response or error, we've already processed an error
embeddedChannel.checkException();
assertNull(embeddedChannel.outboundMessages().poll());
}
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
final HandshakeState clientHandshakeState =
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
clientHandshakeState.start();
return clientHandshakeState;
}
private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload)
throws ShortBufferException {
// Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag
final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16];
int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length);
assertEquals(written, initiatorMessageBytes.length);
return initiatorMessageBytes;
}
private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message)
throws ShortBufferException, BadPaddingException {
// 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload
final int expectedResponsePayloadLength = message.length - 32 - 16;
final byte[] responsePayload = new byte[expectedResponsePayloadLength];
final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0);
assertEquals(expectedResponsePayloadLength, responsePayloadLength);
return responsePayload;
}
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState = clientHandshakeState();
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(
Unpooled.wrappedBuffer(initiatorMessage));
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
assertEquals(0, initiatorMessageFrame.refCnt());
if (!await.isSuccess()) {
throw await.cause();
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
// rethrow if running the task caused an error
embeddedChannel.checkException();
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
return clientHandshakeState.split();
}
private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
clientIdentityPayloadBuffer.put(deviceId);
clientIdentityPayloadBuffer.flip();
return clientIdentityPayloadBuffer.array();
}
}

View File

@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import java.util.Optional;
/**
* A netty user event that indicates that the noise handshake finished successfully.
*
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
* the server included a payload in the handshake response.
*/
public record NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) {}

View File

@ -0,0 +1,55 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import java.util.Optional;
class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter {
private final NoiseClientHandshakeHelper handshakeHelper;
private final byte[] payload;
NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) {
this.handshakeHelper = handshakeHelper;
this.payload = payload;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
byte[] handshakeMessage = handshakeHelper.write(payload);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
}
}
super.userEventTriggered(context, event);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException {
if (message instanceof BinaryWebSocketFrame frame) {
try {
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content()));
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
handshakeHelper.destroy();
}
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
public class NoiseClientHandshakeHelper {
private final HandshakePattern handshakePattern;
private final HandshakeState handshakeState;
private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) {
this.handshakePattern = handshakePattern;
this.handshakeState = handshakeState;
}
static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
state.start();
return new NoiseClientHandshakeHelper(HandshakePattern.IK, state);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
state.start();
return new NoiseClientHandshakeHelper(HandshakePattern.NK, state);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
byte[] write(final byte[] requestPayload) throws ShortBufferException {
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
return initiateHandshakeMessage;
}
private int initiateHandshakeKeysLength() {
return switch (handshakePattern) {
// 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag
case IK -> 32 + 32 + 16;
// 32-byte ephemeral key
case NK -> 32;
};
}
byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
// Don't process additional messages if the handshake failed and we're just waiting to close
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
}
final int payloadLength = responderHandshakeMessage.length - 16 - 32;
final byte[] responsePayload = new byte[payloadLength];
final int payloadBytesRead;
try {
payloadBytesRead = handshakeState
.readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0);
if (payloadBytesRead != responsePayload.length) {
throw new IllegalStateException(
"Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead);
}
return responsePayload;
} catch (ShortBufferException e) {
throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage());
} catch (BadPaddingException e) {
throw new NoiseHandshakeException(e.getMessage());
}
}
CipherStatePair split() {
return this.handshakeState.split();
}
void destroy() {
this.handshakeState.destroy();
}
}

View File

@ -11,30 +11,26 @@ import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
*/
class NoiseTransportHandler extends ChannelDuplexHandler {
class NoiseClientTransportHandler extends ChannelDuplexHandler {
private final CipherStatePair cipherStatePair;
private static final Logger log = LoggerFactory.getLogger(NoiseTransportHandler.class);
private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class);
NoiseTransportHandler(CipherStatePair cipherStatePair) {
NoiseClientTransportHandler(CipherStatePair cipherStatePair) {
this.cipherStatePair = cipherStatePair;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws ShortBufferException, BadPaddingException {
if (message instanceof BinaryWebSocketFrame frame) {
try {
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
final CipherState cipherState = cipherStatePair.getReceiver();
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
@ -45,22 +41,22 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
} finally {
frame.release();
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
} finally {
ReferenceCountUtil.release(message);
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf plaintext) {
try {
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
final CipherState cipherState = cipherStatePair.getSender();
final int plaintextLength = plaintext.readableBytes();
@ -75,7 +71,7 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
} finally {
plaintext.release();
ReferenceCountUtil.release(plaintext);
}
} else {
if (!(message instanceof WebSocketFrame)) {
@ -83,7 +79,6 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
// get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import java.nio.charset.StandardCharsets;
import javax.crypto.ShortBufferException;
import io.netty.buffer.ByteBufUtil;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
public class NoiseHandshakeHelperTest {
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8));
}
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]);
}
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, new byte[0], new byte[0]);
}
void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException {
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair);
NoiseClientHandshakeHelper clientHelper = switch (pattern) {
case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair);
case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey());
};
final byte[] initiate = clientHelper.write(requestPayload);
final ByteBuf actualRequestPayload = serverHelper.read(initiate);
assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload);
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE);
final byte[] respond = serverHelper.write(responsePayload);
byte[] actualResponsePayload = clientHelper.read(respond);
assertThat(actualResponsePayload).isEqualTo(responsePayload);
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT);
assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split());
assertThatNoException().isThrownBy(() -> clientHelper.split());
}
}

View File

@ -1,47 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.util.Optional;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler {
private boolean receivedServerStaticKeyMessage = false;
NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) {
super(rootPublicKey);
}
@Override
protected String getNoiseProtocolName() {
return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME;
}
@Override
protected void startHandshake() {
getHandshakeState().start();
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof BinaryWebSocketFrame frame) {
try {
// Don't process additional messages if we're just waiting to close because the handshake failed
if (receivedServerStaticKeyMessage) {
return;
}
receivedServerStaticKeyMessage = true;
handleServerStaticKeyMessage(context, frame);
context.pipeline().replace(this, null, new NoiseTransportHandler(getHandshakeState().split()));
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
}

View File

@ -1,84 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
@Override
protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
final byte[] serverPublicKeySignature) {
return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature);
}
@Test
void handleCompleteHandshake()
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class));
final HandshakeState clientHandshakeState =
new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
clientHandshakeState.start();
{
final byte[] ephemeralKeyMessageBytes = new byte[32];
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
}
{
assertEquals(1, embeddedChannel.outboundMessages().size());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
final byte[] serverPublicKeySignature = new byte[64];
final int payloadLength =
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
assertEquals(serverPublicKeySignature.length, payloadLength);
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
}
assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
}
}

View File

@ -1,135 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import static org.junit.jupiter.api.Assertions.*;
class NoiseTransportHandlerTest extends AbstractLeakDetectionTest {
private CipherStatePair clientCipherStatePair;
private EmbeddedChannel embeddedChannel;
// We use an NN handshake for this test just because it's a little shorter and easier to set up
private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b";
@BeforeEach
void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException {
final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER);
clientHandshakeState.start();
serverHandshakeState.start();
final byte[] clientEphemeralKeyMessage = new byte[32];
assertEquals(clientEphemeralKeyMessage.length,
clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0));
serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
// 32 bytes of key material plus a 16-byte MAC
final byte[] serverEphemeralKeyMessage = new byte[48];
assertEquals(serverEphemeralKeyMessage.length,
serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0));
clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
clientCipherStatePair = clientHandshakeState.split();
embeddedChannel = new EmbeddedChannel(new NoiseTransportHandler(serverHandshakeState.split()));
clientHandshakeState.destroy();
serverHandshakeState.destroy();
}
@Test
void channelRead() throws ShortBufferException, InterruptedException {
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
assertEquals(0, ciphertextFrame.refCnt());
final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll();
assertNotNull(decryptedPlaintextBuffer);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer);
decryptedPlaintextBuffer.release();
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void channelReadBadCiphertext() throws InterruptedException {
final byte[] bogusCiphertext = new byte[32];
ThreadLocalRandom.current().nextBytes(bogusCiphertext);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
assertEquals(0, ciphertextFrame.refCnt());
assertFalse(readCiphertextFuture.isSuccess());
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelReadUnexpectedMessageType() throws InterruptedException {
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
assertFalse(readFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void write() throws InterruptedException, ShortBufferException, BadPaddingException {
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.await().isSuccess());
assertEquals(0, plaintextBuffer.refCnt());
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
assertNotNull(ciphertextFrame);
assertTrue(embeddedChannel.outboundMessages().isEmpty());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
ciphertextFrame.release();
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void writeUnexpectedMessageType() throws InterruptedException {
final Object unexpectedMessaged = new Object();
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
assertTrue(writeFuture.await().isSuccess());
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
assertTrue(embeddedChannel.outboundMessages().isEmpty());
}
}

View File

@ -1,22 +1,26 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Supplier;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.http.HttpHeaders;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import javax.annotation.Nullable;
class NoiseWebSocketTunnelClient implements AutoCloseable {
@ -26,19 +30,86 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
public NoiseWebSocketTunnelClient(final SocketAddress remoteServerAddress,
final URI websocketUri,
final boolean authenticated,
final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
@Nullable final Supplier<HAProxyMessage> proxyMessageSupplier,
final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) {
static class Builder {
final SocketAddress remoteServerAddress;
NioEventLoopGroup eventLoopGroup;
ECPublicKey serverPublicKey;
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
HttpHeaders headers = new DefaultHttpHeaders();
WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER;
boolean authenticated = false;
ECKeyPair ecKeyPair = null;
UUID accountIdentifier = null;
byte deviceId = 0x00;
boolean useTls;
X509Certificate trustedServerCertificate = null;
Supplier<HAProxyMessage> proxyMessageSupplier = null;
Builder(
final SocketAddress remoteServerAddress,
final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
this.remoteServerAddress = remoteServerAddress;
this.eventLoopGroup = eventLoopGroup;
this.serverPublicKey = serverPublicKey;
}
Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
this.authenticated = true;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.ecKeyPair = ecKeyPair;
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
return this;
}
Builder setWebsocketUri(final URI websocketUri) {
this.websocketUri = websocketUri;
return this;
}
Builder setUseTls(X509Certificate trustedServerCertificate) {
this.useTls = true;
this.trustedServerCertificate = trustedServerCertificate;
return this;
}
Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
this.proxyMessageSupplier = proxyMessageSupplier;
return this;
}
Builder setHeaders(final HttpHeaders headers) {
this.headers = headers;
return this;
}
Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
return this;
}
Builder setServerPublicKey(ECPublicKey serverPublicKey) {
this.serverPublicKey = serverPublicKey;
return this;
}
NoiseWebSocketTunnelClient build() {
final NoiseWebSocketTunnelClient client =
new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler(
useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey,
accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier,
fastOpenRequest));
client.start();
return client;
}
}
private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup,
Function<byte[], EstablishRemoteConnectionHandler> handler) {
this.serverBootstrap = new ServerBootstrap()
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
@ -47,28 +118,37 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(useTls,
trustedServerCertificate,
websocketUri,
authenticated,
ecKeyPair,
rootPublicKey,
accountIdentifier,
deviceId,
headers,
remoteServerAddress,
webSocketCloseListener,
proxyMessageSupplier));
localChannel.pipeline()
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
// in the handshake.
.addLast(Http2Buffering.handler())
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
// connect to the remote service
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
requestBufferedEvent.fastOpenRequest().release();
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
}
super.userEventTriggered(ctx, evt);
}
})
.addLast(new ClientErrorHandler());
}
});
}
LocalAddress getLocalAddress() {
return (LocalAddress) serverChannel.localAddress();
}
NoiseWebSocketTunnelClient start() throws InterruptedException {
serverChannel = serverBootstrap.bind().await().channel();
private NoiseWebSocketTunnelClient start() {
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
return this;
}
@ -76,4 +156,5 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
public void close() throws InterruptedException {
serverChannel.close().await();
}
}

View File

@ -50,6 +50,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.apache.commons.lang3.RandomStringUtils;
@ -67,7 +68,6 @@ import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
@ -89,7 +89,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private ClientConnectionManager clientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair rootKeyPair;
private ECKeyPair serverKeyPair;
private ECKeyPair clientKeyPair;
private ManagedLocalGrpcServer authenticatedGrpcServer;
@ -153,9 +153,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
}
rootKeyPair = Curve.generateKeyPair();
clientKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
serverKeyPair = Curve.generateKeyPair();
clientConnectionManager = new ClientConnectionManager();
@ -192,14 +191,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer.start();
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
new X509Certificate[] { serverTlsCertificate },
new X509Certificate[]{serverTlsCertificate},
serverTlsPrivateKey,
nioEventLoopGroup,
delegatedTaskExecutor,
clientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
@ -214,7 +212,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
clientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
@ -241,9 +238,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
@ParameterizedTest
@ValueSource(booleans = { true, false })
@ValueSource(booleans = {true, false})
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -259,24 +258,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
@ParameterizedTest
@ValueSource(booleans = { true, false })
@ValueSource(booleans = {true, false})
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
nioEventLoopGroup,
WebSocketCloseListener.NOOP_LISTENER)
.start()) {
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -295,10 +283,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseWebSocketTunnelClient client =
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
@ -312,7 +301,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
@ -322,7 +312,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -335,7 +327,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
@ -345,8 +338,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final NoiseWebSocketTunnelClient client =
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
@ -360,29 +354,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedToAnonymousService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
URI.create("wss://localhost/anonymous"),
true,
clientKeyPair,
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI)
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -395,12 +378,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymous() throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = buildAndStartAnonymousClient()) {
try (final NoiseWebSocketTunnelClient client = anonymous().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -420,9 +404,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseWebSocketTunnelClient client =
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
try (final NoiseWebSocketTunnelClient client = anonymous()
.setWebSocketCloseListener(webSocketCloseListener)
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -435,29 +420,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymousToAuthenticatedService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
URI.create("wss://localhost/authenticated"),
false,
null,
rootKeyPair.getPublicKey(),
null,
(byte) 0,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
try (final NoiseWebSocketTunnelClient client = anonymous()
.setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@ -470,7 +444,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
@ -506,8 +481,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
.uri(authenticatedUri)
.PUT(HttpRequest.BodyPublishers.ofString("test"))
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"Non-GET requests should not be allowed");
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
@ -538,8 +513,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
.add("Accept-Language", acceptLanguage)
.add("User-Agent", userAgent);
try (final NoiseWebSocketTunnelClient client =
buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) {
try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
@ -582,7 +556,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
};
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
@ -606,63 +582,24 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setUseTls(serverTlsCertificate);
}
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
throws InterruptedException {
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false);
private NoiseWebSocketTunnelClient.Builder authenticated() {
return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
.setUseTls(serverTlsCertificate);
}
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey,
final HttpHeaders headers,
final boolean includeProxyMessage) throws InterruptedException {
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
rootPublicKey,
ACCOUNT_IDENTIFIER,
DEVICE_ID,
headers,
true,
serverTlsCertificate,
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
nioEventLoopGroup,
webSocketCloseListener)
.start();
}
private NoiseWebSocketTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
}
private NoiseWebSocketTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI,
false,
null,
rootPublicKey,
null,
(byte) 0,
headers,
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start();
}
private static HAProxyMessage buildProxyMessage() {
return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443);
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
return includeProxyMesage
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443)
: null;
}
}

View File

@ -1,89 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.UUID;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler {
private final ECKeyPair ecKeyPair;
private final UUID accountIdentifier;
private final byte deviceId;
private boolean receivedServerStaticKeyMessage = false;
NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
final UUID accountIdentifier,
final byte deviceId) {
super(rootPublicKey);
this.ecKeyPair = ecKeyPair;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
}
@Override
protected String getNoiseProtocolName() {
return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME;
}
@Override
protected void startHandshake() {
final HandshakeState handshakeState = getHandshakeState();
// Noise-java derives the public key from the private key, so we just need to set the private key
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
handshakeState.start();
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException, ShortBufferException {
if (message instanceof BinaryWebSocketFrame frame) {
try {
// Don't process additional messages if the handshake failed and we're just waiting to close
if (receivedServerStaticKeyMessage) {
return;
}
receivedServerStaticKeyMessage = true;
handleServerStaticKeyMessage(context, frame);
final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17);
clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits());
clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits());
clientIdentityBuffer.put(deviceId);
clientIdentityBuffer.flip();
final HandshakeState handshakeState = getHandshakeState();
// We're sending two 32-byte keys plus the client identity payload
final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()];
handshakeState.writeMessage(
staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining());
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
context.pipeline().replace(this, null, new NoiseTransportHandler(handshakeState.split()));
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
}

View File

@ -1,453 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import io.netty.util.internal.EmptyArrays;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
private ClientPublicKeysManager clientPublicKeysManager;
@Override
@BeforeEach
void setUp() {
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
super.setUp();
}
@Override
protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
final byte[] serverPublicKeySignature) {
return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature);
}
@Test
void handleCompleteHandshake()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
}
@Test
void handleCompleteHandshakeMissingIdentityInformation()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
{
final byte[] clientStaticKeyMessageBytes = new byte[64];
final int messageLength =
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0);
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
final ChannelFuture writeClientStaticKeyMessageFuture =
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeMalformedIdentityInformation()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
{
final byte[] clientStaticKeyMessageBytes = new byte[96];
final int messageLength =
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32);
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
final ChannelFuture writeClientStaticKeyMessageFuture =
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeUnrecognizedDevice()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakePublicKeyMismatch()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeBufferedReads()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
final ByteBuf[] additionalMessages = new ByteBuf[4];
final CipherState senderState = clientHandshakeState.split().getSender();
try {
for (int i = 0; i < additionalMessages.length; i++) {
final byte[] contentBytes = new byte[32];
ThreadLocalRandom.current().nextBytes(contentBytes);
// Copy the "plaintext" portion of the content bytes for future assertions
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
// tag
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
assertTrue(
embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await()
.isSuccess());
}
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
for (final ByteBuf additionalMessage : additionalMessages) {
assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(),
"Buffered message should pass through pipeline after successful handshake");
}
} finally {
for (final ByteBuf additionalMessage : additionalMessages) {
additionalMessage.release();
}
}
}
@Test
void handleCompleteHandshakeFailureBufferedReads()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
final ByteBuf[] additionalMessages = new ByteBuf[4];
final CipherState senderState = clientHandshakeState.split().getSender();
try {
for (int i = 0; i < additionalMessages.length; i++) {
final byte[] contentBytes = new byte[32];
ThreadLocalRandom.current().nextBytes(contentBytes);
// Copy the "plaintext" portion of the content bytes for future assertions
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
// tag
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess());
}
findPublicKeyFuture.complete(Optional.empty());
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
assertTrue(embeddedChannel.inboundMessages().isEmpty(),
"Buffered messages should not pass through pipeline after failed handshake");
} finally {
for (final ByteBuf additionalMessage : additionalMessages) {
additionalMessage.release();
}
}
}
private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair)
throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState =
new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
clientHandshakeState.start();
{
final byte[] ephemeralKeyMessageBytes = new byte[32];
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
}
{
assertEquals(1, embeddedChannel.outboundMessages().size());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
final byte[] serverPublicKeySignature = new byte[64];
final int payloadLength =
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
assertEquals(serverPublicKeySignature.length, payloadLength);
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
}
return clientHandshakeState;
}
private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId)
throws ShortBufferException, InterruptedException {
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
clientIdentityPayloadBuffer.put(deviceId);
clientIdentityPayloadBuffer.flip();
final byte[] clientStaticKeyMessageBytes = new byte[81];
final int messageLength =
handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining());
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
}

View File

@ -0,0 +1,80 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline
* after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound
* messages to be bytes.
*/
abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class);
/**
* Handle an inbound message. The frame will be automatically released after the method is finished running.
*
* @param context The current {@link ChannelHandlerContext}
* @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise
* packet
* @throws Exception
*/
abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception;
/**
* Handle an outbound byte message. The message will be automatically released after the method is finished running.
*
* @param context The current {@link ChannelHandlerContext}
* @param bytes The bytes to write
* @throws Exception
*/
abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes,
final ChannelPromise promise) throws Exception;
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
handleInbound(context, frame.content());
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} finally {
ReferenceCountUtil.release(message);
}
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf serverResponse) {
try {
handleOutbound(context, serverResponse, promise);
} finally {
ReferenceCountUtil.release(serverResponse);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
}

View File

@ -79,7 +79,6 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
Curve.generateKeyPair(),
new byte[64],
RECOGNIZED_PROXY_SECRET),
userEventRecordingHandler);
@ -88,7 +87,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends ChannelHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
@ -102,8 +101,8 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class),
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class));
}
@Test

View File

@ -130,5 +130,7 @@ linkDevice.secret: AAAAAAAAAAA=
tlsKeyStore.password: unset
noiseTunnel.noiseStaticPrivateKey: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
# The below private key was generated exclusively for testing purposes. Do not use it in any other context.
# Corresponding public key: cYUAFtkWK/4x3AfW/yw7qgIo/mQUaRSWaPolGQkiL14=
noiseTunnel.noiseStaticPrivateKey: qK5FD9WmuhoLPsS/Z4swcZkwDn9OpeM5ZmcEVMpEQ24=
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA

View File

@ -452,7 +452,6 @@ callingTurnManualTable:
noiseTunnel:
port: 8443
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
externalRequestFilter: