Replace XX/NX handshakes with IK/NK
This commit is contained in:
parent
c835d85256
commit
542422b7b8
|
@ -478,7 +478,6 @@ noiseTunnel:
|
|||
tlsKeyStoreEntryAlias: example.com
|
||||
tlsKeyStorePassword: secret://noiseTunnel.tlsKeyStorePassword
|
||||
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
|
||||
noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
|
||||
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
|
||||
|
||||
externalRequestFilter:
|
||||
|
|
|
@ -933,7 +933,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
clientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
config.getNoiseWebSocketTunnelConfiguration().noiseStaticKeyPair(),
|
||||
config.getNoiseWebSocketTunnelConfiguration().noiseRootPublicKeySignature(),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
config.getNoiseWebSocketTunnelConfiguration().recognizedProxySecret().value());
|
||||
|
|
|
@ -15,7 +15,6 @@ public record NoiseWebSocketTunnelConfiguration(@Positive int port,
|
|||
@Nullable String tlsKeyStoreEntryAlias,
|
||||
@Nullable SecretString tlsKeyStorePassword,
|
||||
@NotNull SecretBytes noiseStaticPrivateKey,
|
||||
@NotNull byte[] noiseRootPublicKeySignature,
|
||||
@NotNull SecretString recognizedProxySecret) {
|
||||
|
||||
public ECKeyPair noiseStaticKeyPair() throws InvalidKeyException {
|
||||
|
|
|
@ -1,124 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* An abstract base class for XX- and NX-patterned Noise responder handshake handlers.
|
||||
*
|
||||
* @see <a href="https://noiseprotocol.org/noise.html">The Noise Protocol Framework</a>
|
||||
*/
|
||||
abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
private final byte[] publicKeySignature;
|
||||
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32;
|
||||
|
||||
/**
|
||||
* Constructs a new Noise handler with the given static server keys and static public key signature. The static public
|
||||
* key must be signed by a trusted root private key whose public key is known to and trusted by authenticating
|
||||
* clients.
|
||||
*
|
||||
* @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler
|
||||
* @param ecKeyPair the static key pair for this server
|
||||
* @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key
|
||||
*/
|
||||
AbstractNoiseHandshakeHandler(final String noiseProtocolName,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature) {
|
||||
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.publicKeySignature = publicKeySignature;
|
||||
|
||||
try {
|
||||
this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER);
|
||||
} catch (final NoSuchAlgorithmException e) {
|
||||
throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e);
|
||||
}
|
||||
}
|
||||
|
||||
protected HandshakeState getHandshakeState() {
|
||||
return handshakeState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's
|
||||
* static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server.
|
||||
* Clients must not include an additional payload with their ephemeral key message. The server's reply contains its
|
||||
* static keys along with an Ed25519 signature of its public static key by a trusted root key.
|
||||
*
|
||||
* @param context the channel handler context for this message
|
||||
* @param frame the websocket frame containing the ephemeral key message
|
||||
*
|
||||
* @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a
|
||||
* general Noise encryption error occurred
|
||||
*/
|
||||
protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||
throws NoiseHandshakeException {
|
||||
|
||||
if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) {
|
||||
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
|
||||
}
|
||||
|
||||
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
|
||||
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
|
||||
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
|
||||
// key), so we just set the private key here.
|
||||
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
|
||||
handshakeState.start();
|
||||
|
||||
// The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is
|
||||
// coming off the wire and so will be in a direct buffer that doesn't have a backing array.
|
||||
final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content());
|
||||
frame.content().readBytes(ephemeralKeyMessage);
|
||||
|
||||
try {
|
||||
handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen since we're checking the length of the frame up front
|
||||
throw new NoiseHandshakeException("Unexpected client payload");
|
||||
} catch (final BadPaddingException e) {
|
||||
// It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag"
|
||||
// subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an
|
||||
// ephemeral-key-only message
|
||||
throw new NoiseHandshakeException("Invalid keys");
|
||||
}
|
||||
|
||||
// Send our key material and public key signature back to the client; this buffer will include:
|
||||
//
|
||||
// - A 32-byte plaintext ephemeral key
|
||||
// - A 32-byte encrypted static key
|
||||
// - A 16-byte AEAD tag for the static key
|
||||
// - The public key signature payload
|
||||
// - A 16-byte AEAD tag for the payload
|
||||
final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16];
|
||||
|
||||
try {
|
||||
handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length);
|
||||
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen for messages of known length that we control
|
||||
throw new AssertionError("Key material buffer was too short for message", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
handshakeState.destroy();
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
|||
import javax.crypto.BadPaddingException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
|
||||
|
@ -38,7 +39,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter {
|
|||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final WebSocketCloseStatus webSocketCloseStatus = switch (cause) {
|
||||
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
|
||||
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
|
||||
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||
|
@ -51,6 +52,7 @@ class ErrorHandler extends ChannelInboundHandlerAdapter {
|
|||
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error occurred before websocket handshake complete", cause);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
context.close();
|
||||
|
|
|
@ -48,12 +48,12 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
|
||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||
// authenticated service.
|
||||
final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent()
|
||||
final LocalAddress grpcServerAddress = noiseIdentityDeterminedEvent.authenticatedDevice().isPresent()
|
||||
? authenticatedGrpcServerAddress
|
||||
: anonymousGrpcServerAddress;
|
||||
|
||||
|
@ -72,7 +72,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
if (localChannelFuture.isSuccess()) {
|
||||
clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
|
||||
remoteChannelContext.channel(),
|
||||
noiseHandshakeCompleteEvent.authenticatedDevice());
|
||||
noiseIdentityDeterminedEvent.authenticatedDevice());
|
||||
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
enum HandshakePattern {
|
||||
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
||||
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
||||
|
||||
private final String protocol;
|
||||
|
||||
public String protocol() {
|
||||
return protocol;
|
||||
}
|
||||
|
||||
|
||||
HandshakePattern(String protocol) {
|
||||
this.protocol = protocol;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
|
||||
* and noise encryption/decryption.
|
||||
* <p>
|
||||
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
||||
* indicating that initiator connected anonymously.
|
||||
*/
|
||||
class NoiseAnonymousHandler extends NoiseHandler {
|
||||
|
||||
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
|
||||
return CompletableFuture.completedFuture(new HandshakeResult(
|
||||
handshakePayload,
|
||||
Optional.empty()
|
||||
));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
|
||||
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
|
||||
* key is authenticated by the responder.
|
||||
* <p>
|
||||
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
|
||||
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
|
||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
||||
* initial message delay (fast open).
|
||||
* <pre>
|
||||
* +-----------------+----------------+------------------------+
|
||||
* | UUID (16) | deviceId (1) | request bytes (N) |
|
||||
* +-----------------+----------------+------------------------+
|
||||
* </pre>
|
||||
* <p>
|
||||
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
|
||||
* key for the device identified by the provided ACI and deviceId.
|
||||
* <p>
|
||||
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
||||
*/
|
||||
class NoiseAuthenticatedHandler extends NoiseHandler {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair) {
|
||||
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException {
|
||||
if (handshakePayload.readableBytes() < 17) {
|
||||
throw new NoiseHandshakeException("Invalid handshake payload");
|
||||
}
|
||||
|
||||
final byte[] publicKeyFromClient = initiatorPublicKey
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
|
||||
// Advances the read index by 16 bytes
|
||||
final UUID accountIdentifier = parseUUID(handshakePayload);
|
||||
|
||||
// Advances the read index by 1 byte
|
||||
final byte deviceId = handshakePayload.readByte();
|
||||
|
||||
final ByteBuf fastOpenRequest = handshakePayload.slice();
|
||||
return clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.handleAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
ReferenceCountUtil.release(fastOpenRequest);
|
||||
throw ExceptionUtils.wrap(throwable);
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
throw ExceptionUtils.wrap(new ClientAuthenticationException());
|
||||
}
|
||||
return new HandshakeResult(
|
||||
fastOpenRequest,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
|
||||
*
|
||||
* @param bytes The {@link ByteBuf} to read from
|
||||
* @return The parsed UUID
|
||||
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
|
||||
*/
|
||||
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
|
||||
if (bytes.readableBytes() < 16) {
|
||||
throw new NoiseHandshakeException("Could not parse account identifier");
|
||||
}
|
||||
return new UUID(bytes.readLong(), bytes.readLong());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,195 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherState;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
||||
* inbound messages, and encrypts outbound messages
|
||||
*/
|
||||
abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||
|
||||
private enum State {
|
||||
// Waiting for handshake to complete
|
||||
HANDSHAKE,
|
||||
// Can freely exchange encrypted noise messages on an established session
|
||||
TRANSPORT,
|
||||
// Finished with error
|
||||
ERROR
|
||||
}
|
||||
|
||||
private final NoiseHandshakeHelper handshakeHelper;
|
||||
|
||||
private State state = State.HANDSHAKE;
|
||||
private CipherStatePair cipherStatePair;
|
||||
|
||||
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
|
||||
this.handshakeHelper = handshakeHelper;
|
||||
}
|
||||
|
||||
/**
|
||||
* The result of processing an initiator handshake payload
|
||||
*
|
||||
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
|
||||
* empty ByteBuf
|
||||
* @param authenticatedDevice If present, the successfully authenticated initiator identity
|
||||
*/
|
||||
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
|
||||
|
||||
/**
|
||||
* Parse and potentially authenticate the initiator handshake message
|
||||
*
|
||||
* @param context A {@link ChannelHandlerContext}
|
||||
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
|
||||
* @param handshakePayload The handshake payload provided in the initiator message
|
||||
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
|
||||
* present in the handshake payload.
|
||||
* @throws NoiseHandshakeException If the handshake payload was invalid
|
||||
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
|
||||
*/
|
||||
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
|
||||
final ChannelHandlerContext context,
|
||||
final Optional<byte[]> initiatorPublicKey,
|
||||
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer.
|
||||
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
fail(context, e);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
|
||||
switch (state) {
|
||||
|
||||
// Got an initiator handshake message
|
||||
case HANDSHAKE -> {
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
|
||||
(result, throwable) -> {
|
||||
if (state == State.ERROR) {
|
||||
return;
|
||||
}
|
||||
if (throwable != null) {
|
||||
fail(context, ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
this.state = State.TRANSPORT;
|
||||
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
|
||||
if (result.fastOpenRequest().isReadable()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(result.fastOpenRequest());
|
||||
} else {
|
||||
ReferenceCountUtil.release(result.fastOpenRequest());
|
||||
}
|
||||
}, context.executor());
|
||||
}
|
||||
|
||||
// Got a client message that should be decrypted and forwarded
|
||||
case TRANSPORT -> {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||
final int plaintextLength = cipherState.decryptWithAd(null,
|
||||
frameBytes, 0,
|
||||
frameBytes, 0,
|
||||
frameBytes.length);
|
||||
|
||||
// Forward the decrypted plaintext along
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
||||
}
|
||||
|
||||
// The session is already in an error state, drop the message
|
||||
case ERROR -> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
|
||||
* context
|
||||
*/
|
||||
private void fail(final ChannelHandlerContext context, final Throwable cause) {
|
||||
this.state = State.ERROR;
|
||||
context.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (message instanceof ByteBuf plaintext) {
|
||||
try {
|
||||
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
||||
final CipherState cipherState = cipherStatePair.getSender();
|
||||
final int plaintextLength = plaintext.readableBytes();
|
||||
|
||||
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
||||
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
||||
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
||||
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
||||
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
||||
|
||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||
|
||||
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
|
||||
|
||||
} finally {
|
||||
ReferenceCountUtil.release(plaintext);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof WebSocketFrame)) {
|
||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
||||
// get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process.
|
||||
*
|
||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||
* type that performs authentication
|
||||
*/
|
||||
record NoiseHandshakeCompleteEvent(Optional<AuthenticatedDevice> authenticatedDevice) {
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* Helper for the responder of a 2-message handshake with a pre-shared responder static key
|
||||
*/
|
||||
class NoiseHandshakeHelper {
|
||||
|
||||
private final static int AEAD_TAG_LENGTH = 16;
|
||||
private final static int KEY_LENGTH = 32;
|
||||
|
||||
private final HandshakePattern handshakePattern;
|
||||
private final ECKeyPair serverStaticKeyPair;
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) {
|
||||
this.handshakePattern = handshakePattern;
|
||||
this.serverStaticKeyPair = serverStaticKeyPair;
|
||||
try {
|
||||
this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER);
|
||||
} catch (final NoSuchAlgorithmException e) {
|
||||
throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the length of the initiator's keys
|
||||
*
|
||||
* @return length of the handshake message sent by the remote party (the initiator) not including the payload
|
||||
*/
|
||||
private int initiatorHandshakeMessageKeyLength() {
|
||||
return switch (handshakePattern) {
|
||||
// ephemeral key, static key (encrypted), AEAD tag for static key
|
||||
case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH;
|
||||
// ephemeral key only
|
||||
case NK -> KEY_LENGTH;
|
||||
};
|
||||
}
|
||||
|
||||
HandshakeState getHandshakeState() {
|
||||
return this.handshakeState;
|
||||
}
|
||||
|
||||
ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException {
|
||||
if (handshakeState.getAction() != HandshakeState.NO_ACTION) {
|
||||
throw new NoiseHandshakeException("Cannot send more data before handshake is complete");
|
||||
}
|
||||
|
||||
// Length for an empty payload
|
||||
final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH;
|
||||
if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
|
||||
}
|
||||
|
||||
final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH;
|
||||
|
||||
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
|
||||
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
|
||||
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
|
||||
// key), so we just set the private key here.
|
||||
handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0);
|
||||
handshakeState.start();
|
||||
|
||||
int payloadBytesRead;
|
||||
|
||||
try {
|
||||
payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length,
|
||||
remoteHandshakeMessage, 0);
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen since we're checking the length of the frame up front
|
||||
throw new NoiseHandshakeException("Unexpected client payload");
|
||||
} catch (final BadPaddingException e) {
|
||||
// We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key
|
||||
// or payload
|
||||
throw new NoiseHandshakeException("Invalid keys or payload");
|
||||
}
|
||||
if (payloadBytesRead != payloadLength) {
|
||||
throw new NoiseHandshakeException(
|
||||
"Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead);
|
||||
}
|
||||
return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead);
|
||||
}
|
||||
|
||||
byte[] write(byte[] payload) {
|
||||
if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) {
|
||||
throw new IllegalStateException("Cannot send data before handshake is complete");
|
||||
}
|
||||
|
||||
// Currently only support handshake patterns where the server static key is known
|
||||
// Send our ephemeral key and the response to the initiator with the encrypted payload
|
||||
final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH];
|
||||
try {
|
||||
int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length);
|
||||
if (written != response.length) {
|
||||
throw new IllegalStateException("Unexpected handshake response length");
|
||||
}
|
||||
return response;
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen for messages of known length that we control
|
||||
throw new IllegalStateException("Key material buffer was too short for message", e);
|
||||
}
|
||||
}
|
||||
|
||||
Optional<byte[]> remotePublicKey() {
|
||||
return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> {
|
||||
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
|
||||
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
|
||||
return publicKeyFromClient;
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
/**
|
||||
* An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is
|
||||
* connecting anonymously, the identity is empty, otherwise it will be present and already authenticated.
|
||||
*
|
||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||
* type that performs authentication
|
||||
*/
|
||||
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
|
@ -1,40 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.util.Optional;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* A Noise NX handler handles the responder side of a Noise NX handshake.
|
||||
*/
|
||||
class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler {
|
||||
|
||||
static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b";
|
||||
|
||||
NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) {
|
||||
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
handleEphemeralKeyMessage(context, frame);
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
|
||||
// All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider
|
||||
// the handshake complete
|
||||
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||
context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseTransportHandler(getHandshakeState().split()));
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
ReferenceCountUtil.release(message);
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -53,7 +53,6 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
|||
final ClientConnectionManager clientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws SSLException {
|
||||
|
@ -107,7 +106,7 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
|||
.addLast(new RejectUnsupportedMessagesHandler())
|
||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||
// a WebSocket handshake has been completed
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret))
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||
|
|
|
@ -1,178 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
/**
|
||||
* A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send
|
||||
* identifying information (an account identifier and device ID) as an additional payload when sending its static key
|
||||
* material. It compares the static public key against the stored public key for the identified device asynchronously,
|
||||
* buffering traffic from the client until the authentication check completes.
|
||||
*/
|
||||
class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler {
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY;
|
||||
|
||||
private final List<BinaryWebSocketFrame> pendingInboundFrames = new ArrayList<>();
|
||||
|
||||
static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b";
|
||||
|
||||
// When the client sends its static key message, we expect:
|
||||
//
|
||||
// - A 32-byte encrypted static public key
|
||||
// - A 16-byte AEAD tag for the static key
|
||||
// - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID)
|
||||
// - A 16-byte AEAD tag for the identity payload
|
||||
private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81;
|
||||
|
||||
private enum AuthenticationState {
|
||||
GET_EPHEMERAL_KEY,
|
||||
GET_STATIC_KEY,
|
||||
CHECK_PUBLIC_KEY,
|
||||
ERROR
|
||||
}
|
||||
|
||||
public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature) {
|
||||
|
||||
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
switch (authenticationState) {
|
||||
case GET_EPHEMERAL_KEY -> {
|
||||
try {
|
||||
handleEphemeralKeyMessage(context, frame);
|
||||
authenticationState = AuthenticationState.GET_STATIC_KEY;
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
}
|
||||
case GET_STATIC_KEY -> {
|
||||
try {
|
||||
handleStaticKey(context, frame);
|
||||
authenticationState = AuthenticationState.CHECK_PUBLIC_KEY;
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
}
|
||||
case CHECK_PUBLIC_KEY -> {
|
||||
// Buffer any inbound traffic until we've finished checking the client's public key
|
||||
pendingInboundFrames.add(frame);
|
||||
}
|
||||
case ERROR -> {
|
||||
// If authentication has failed for any reason, just discard inbound traffic until the channel closes
|
||||
frame.release();
|
||||
}
|
||||
}
|
||||
} catch (final ShortBufferException e) {
|
||||
authenticationState = AuthenticationState.ERROR;
|
||||
throw new NoiseHandshakeException("Unexpected payload length");
|
||||
} catch (final BadPaddingException e) {
|
||||
authenticationState = AuthenticationState.ERROR;
|
||||
throw new ClientAuthenticationException();
|
||||
}
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
ReferenceCountUtil.release(message);
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||
throws NoiseHandshakeException, ShortBufferException, BadPaddingException {
|
||||
|
||||
if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) {
|
||||
throw new NoiseHandshakeException("Unexpected client static key message length");
|
||||
}
|
||||
|
||||
final HandshakeState handshakeState = getHandshakeState();
|
||||
|
||||
// The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct
|
||||
// buffer into a heap buffer.
|
||||
final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content());
|
||||
|
||||
// The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte)
|
||||
final byte[] payload = new byte[17];
|
||||
|
||||
final UUID accountIdentifier;
|
||||
final byte deviceId;
|
||||
|
||||
final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage,
|
||||
0, staticKeyAndClientIdentityMessage.length, payload, 0);
|
||||
|
||||
if (payloadBytesRead != 17) {
|
||||
throw new NoiseHandshakeException("Unexpected identity payload length");
|
||||
}
|
||||
|
||||
try {
|
||||
accountIdentifier = UUIDUtil.fromBytes(payload, 0);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
throw new NoiseHandshakeException("Could not parse account identifier");
|
||||
}
|
||||
|
||||
deviceId = payload[16];
|
||||
|
||||
// Verify the identity of the caller by comparing the submitted static public key against the stored public key for
|
||||
// the identified device
|
||||
clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)
|
||||
.whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> {
|
||||
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
|
||||
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
|
||||
|
||||
if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) {
|
||||
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))));
|
||||
|
||||
context.pipeline().addAfter(context.name(), null, new NoiseTransportHandler(handshakeState.split()));
|
||||
|
||||
// Flush any buffered reads
|
||||
pendingInboundFrames.forEach(context::fireChannelRead);
|
||||
pendingInboundFrames.clear();
|
||||
|
||||
context.pipeline().remove(NoiseXXHandshakeHandler.this);
|
||||
} else {
|
||||
// We found a key, but it doesn't match what the caller submitted
|
||||
context.fireExceptionCaught(new ClientAuthenticationException());
|
||||
authenticationState = AuthenticationState.ERROR;
|
||||
}
|
||||
},
|
||||
() -> {
|
||||
// We couldn't find a key for the identified account/device
|
||||
context.fireExceptionCaught(new ClientAuthenticationException());
|
||||
authenticationState = AuthenticationState.ERROR;
|
||||
}),
|
||||
context.executor());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
super.handlerRemoved(context);
|
||||
|
||||
pendingInboundFrames.forEach(BinaryWebSocketFrame::release);
|
||||
pendingInboundFrames.clear();
|
||||
}
|
||||
}
|
|
@ -31,7 +31,6 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
private final byte[] publicKeySignature;
|
||||
|
||||
private final byte[] recognizedProxySecret;
|
||||
|
||||
|
@ -45,12 +44,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final byte[] publicKeySignature,
|
||||
final String recognizedProxySecret) {
|
||||
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.publicKeySignature = publicKeySignature;
|
||||
|
||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||
|
@ -84,10 +81,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
|||
|
||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
|
||||
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
||||
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
|
||||
new NoiseAnonymousHandler(ecKeyPair);
|
||||
|
||||
default -> {
|
||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||
|
|
|
@ -1,94 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final ECPublicKey rootPublicKey;
|
||||
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) {
|
||||
this.rootPublicKey = rootPublicKey;
|
||||
|
||||
try {
|
||||
handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR);
|
||||
} catch (final NoSuchAlgorithmException e) {
|
||||
throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e);
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract String getNoiseProtocolName();
|
||||
|
||||
protected abstract void startHandshake();
|
||||
|
||||
protected HandshakeState getHandshakeState() {
|
||||
return handshakeState;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||
startHandshake();
|
||||
|
||||
final byte[] ephemeralKeyMessage = new byte[32];
|
||||
handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0);
|
||||
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
super.userEventTriggered(context, event);
|
||||
}
|
||||
|
||||
protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||
throws NoiseHandshakeException {
|
||||
|
||||
// The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap
|
||||
// buffer so we can Noise at it.
|
||||
final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes());
|
||||
final byte[] serverPublicKeySignature = new byte[64];
|
||||
|
||||
try {
|
||||
frame.content().readBytes(keyMaterialBuffer);
|
||||
|
||||
final int payloadBytesRead =
|
||||
handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0);
|
||||
|
||||
if (payloadBytesRead != 64) {
|
||||
throw new NoiseHandshakeException("Unexpected signature length");
|
||||
}
|
||||
} catch (final ShortBufferException e) {
|
||||
throw new NoiseHandshakeException("Unexpected signature length");
|
||||
} catch (final BadPaddingException e) {
|
||||
throw new NoiseHandshakeException("Invalid keys");
|
||||
} finally {
|
||||
keyMaterialBuffer.release();
|
||||
}
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
|
||||
if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) {
|
||||
throw new NoiseHandshakeException("Invalid server public key signature");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) throws Exception {
|
||||
handshakeState.destroy();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.AEADBadTagException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
protected ECKeyPair serverKeyPair;
|
||||
|
||||
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static class PongHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
try {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) {
|
||||
ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes()))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb)));
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unexpected message type: " + msg);
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Nullable
|
||||
private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null;
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
|
||||
handshakeCompleteEvent = noiseIdentityDeterminedEvent;
|
||||
context.pipeline().addAfter(context.name(), null, new PongHandler());
|
||||
context.pipeline().remove(NoiseHandshakeCompleteHandler.class);
|
||||
} else {
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() {
|
||||
return handshakeCompleteEvent;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
serverKeyPair = Curve.generateKeyPair();
|
||||
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||
embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
embeddedChannel.close();
|
||||
}
|
||||
|
||||
protected EmbeddedChannel getEmbeddedChannel() {
|
||||
return embeddedChannel;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() {
|
||||
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||
}
|
||||
|
||||
protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair);
|
||||
|
||||
protected abstract CipherStatePair doHandshake() throws Throwable;
|
||||
|
||||
/**
|
||||
* Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no
|
||||
* waiting messages in the channel, return null.
|
||||
*/
|
||||
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
|
||||
final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
if (responseFrame == null) {
|
||||
return null;
|
||||
}
|
||||
final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16];
|
||||
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
|
||||
ByteBufUtil.getBytes(responseFrame.content()), 0,
|
||||
plaintext, 0,
|
||||
responseFrame.content().readableBytes());
|
||||
assertEquals(read, plaintext.length);
|
||||
return plaintext;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void handleInvalidInitialMessage() throws InterruptedException {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||
assertEquals(0, content.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
||||
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
|
||||
|
||||
for (int i = 0; i < frames.length; i++) {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
||||
|
||||
embeddedChannel.writeOneInbound(frames[i]).await();
|
||||
}
|
||||
|
||||
for (final BinaryWebSocketFrame frame : frames) {
|
||||
assertEquals(0, frame.refCnt());
|
||||
}
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleNonWebSocketBinaryFrame() throws Throwable {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
|
||||
assertEquals(0, message.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelRead() throws Throwable {
|
||||
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||
final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8);
|
||||
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
||||
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
|
||||
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
|
||||
final byte[] response = readNextPlaintext(clientCipherStatePair);
|
||||
assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response);
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadBadCiphertext() throws Throwable {
|
||||
doHandshake();
|
||||
final byte[] bogusCiphertext = new byte[32];
|
||||
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
|
||||
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
||||
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
assertFalse(readCiphertextFuture.isSuccess());
|
||||
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadUnexpectedMessageType() throws Throwable {
|
||||
doHandshake();
|
||||
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
|
||||
|
||||
assertFalse(readFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void write() throws Throwable {
|
||||
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
|
||||
|
||||
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||
assertTrue(writePlaintextFuture.await().isSuccess());
|
||||
assertEquals(0, plaintextBuffer.refCnt());
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(ciphertextFrame);
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
||||
ciphertextFrame.release();
|
||||
|
||||
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
||||
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
|
||||
|
||||
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void writeUnexpectedMessageType() throws Throwable {
|
||||
doHandshake();
|
||||
final Object unexpectedMessaged = new Object();
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
|
||||
assertTrue(writeFuture.await().isSuccess());
|
||||
|
||||
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.annotation.Nullable;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private ECPublicKey rootPublicKey;
|
||||
|
||||
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Nullable
|
||||
private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null;
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
|
||||
handshakeCompleteEvent = noiseHandshakeCompleteEvent;
|
||||
} else {
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() {
|
||||
return handshakeCompleteEvent;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
final ECKeyPair rootKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||
|
||||
rootPublicKey = rootKeyPair.getPublicKey();
|
||||
|
||||
final byte[] serverPublicKeySignature =
|
||||
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||
|
||||
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||
|
||||
embeddedChannel =
|
||||
new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
embeddedChannel.close();
|
||||
}
|
||||
|
||||
protected EmbeddedChannel getEmbeddedChannel() {
|
||||
return embeddedChannel;
|
||||
}
|
||||
|
||||
protected ECPublicKey getRootPublicKey() {
|
||||
return rootPublicKey;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() {
|
||||
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||
}
|
||||
|
||||
protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature);
|
||||
|
||||
@Test
|
||||
void handleInvalidInitialMessage() throws InterruptedException {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||
assertEquals(0, content.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
||||
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
|
||||
|
||||
for (int i = 0; i < frames.length; i++) {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
||||
|
||||
embeddedChannel.writeOneInbound(frames[i]).await();
|
||||
}
|
||||
|
||||
for (final BinaryWebSocketFrame frame : frames) {
|
||||
assertEquals(0, frame.refCnt());
|
||||
}
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleNonWebSocketBinaryFrame() throws InterruptedException {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
|
||||
assertEquals(0, message.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.grpc.net;
|
|||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
|
@ -19,6 +20,7 @@ import io.netty.handler.ssl.SslContextBuilder;
|
|||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.SocketAddress;
|
||||
import java.net.URI;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -29,6 +31,10 @@ import javax.net.ssl.SSLException;
|
|||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
/**
|
||||
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
|
||||
* gRPC server
|
||||
*/
|
||||
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final boolean useTls;
|
||||
|
@ -36,13 +42,15 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
private final URI websocketUri;
|
||||
private final boolean authenticated;
|
||||
@Nullable private final ECKeyPair ecKeyPair;
|
||||
private final ECPublicKey rootPublicKey;
|
||||
private final ECPublicKey serverPublicKey;
|
||||
@Nullable private final UUID accountIdentifier;
|
||||
private final byte deviceId;
|
||||
private final HttpHeaders headers;
|
||||
private final SocketAddress remoteServerAddress;
|
||||
private final WebSocketCloseListener webSocketCloseListener;
|
||||
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
|
||||
// If provided, will be sent with the payload in the noise handshake
|
||||
private final byte[] fastOpenRequest;
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
|
@ -54,26 +62,28 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
final URI websocketUri,
|
||||
final boolean authenticated,
|
||||
@Nullable final ECKeyPair ecKeyPair,
|
||||
final ECPublicKey rootPublicKey,
|
||||
final ECPublicKey serverPublicKey,
|
||||
@Nullable final UUID accountIdentifier,
|
||||
final byte deviceId,
|
||||
final HttpHeaders headers,
|
||||
final SocketAddress remoteServerAddress,
|
||||
final WebSocketCloseListener webSocketCloseListener,
|
||||
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier) {
|
||||
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier,
|
||||
@Nullable byte[] fastOpenRequest) {
|
||||
|
||||
this.useTls = useTls;
|
||||
this.trustedServerCertificate = trustedServerCertificate;
|
||||
this.websocketUri = websocketUri;
|
||||
this.authenticated = authenticated;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.rootPublicKey = rootPublicKey;
|
||||
this.serverPublicKey = serverPublicKey;
|
||||
this.accountIdentifier = accountIdentifier;
|
||||
this.deviceId = deviceId;
|
||||
this.headers = headers;
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.webSocketCloseListener = webSocketCloseListener;
|
||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
||||
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -104,6 +114,10 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
|
||||
}
|
||||
|
||||
final NoiseClientHandshakeHelper helper = authenticated
|
||||
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
|
||||
: NoiseClientHandshakeHelper.NK(serverPublicKey);
|
||||
|
||||
channel.pipeline()
|
||||
.addLast(new HttpClientCodec())
|
||||
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
|
||||
|
@ -118,22 +132,24 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
Noise.MAX_PACKET_LEN,
|
||||
10_000))
|
||||
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
||||
.addLast(authenticated
|
||||
? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId)
|
||||
: new NoiseNXClientHandshakeHandler(rootPublicKey))
|
||||
// Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done
|
||||
.addLast(new NoiseClientHandshakeHandler(helper, initialPayload()))
|
||||
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
|
||||
throws Exception {
|
||||
if (event instanceof NoiseHandshakeCompleteEvent) {
|
||||
if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) {
|
||||
remoteContext.pipeline()
|
||||
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||
|
||||
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||
|
||||
// If there was a payload response on the handshake, write it back to our gRPC client
|
||||
handshakeCompleteEvent.fastResponse().ifPresent(plaintext ->
|
||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
|
||||
|
||||
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
||||
pendingReads.forEach(localContext::fireChannelRead);
|
||||
pendingReads.clear();
|
||||
|
||||
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
|
||||
}
|
||||
|
||||
|
@ -165,4 +181,18 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
|||
pendingReads.forEach(ReferenceCountUtil::release);
|
||||
pendingReads.clear();
|
||||
}
|
||||
|
||||
private byte[] initialPayload() {
|
||||
if (!authenticated) {
|
||||
return fastOpenRequest;
|
||||
}
|
||||
|
||||
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
|
||||
bb.putLong(accountIdentifier.getMostSignificantBits());
|
||||
bb.putLong(accountIdentifier.getLeastSignificantBits());
|
||||
bb.put(deviceId);
|
||||
bb.put(fastOpenRequest);
|
||||
bb.flip();
|
||||
return bb.array();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
|
||||
record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}
|
|
@ -0,0 +1,184 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandler;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast
|
||||
* open" optimization where the client can send a request along with the noise handshake. There's no direct way to
|
||||
* extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an
|
||||
* inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request.
|
||||
* <p>
|
||||
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
|
||||
* {@link FastOpenRequestBufferedEvent}
|
||||
*/
|
||||
class Http2Buffering {
|
||||
|
||||
/**
|
||||
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
|
||||
*/
|
||||
static ChannelInboundHandler handler() {
|
||||
return new Http2PrefaceHandler();
|
||||
}
|
||||
|
||||
private Http2Buffering() {
|
||||
}
|
||||
|
||||
private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5
|
||||
private static final byte[] HTTP2_PREFACE =
|
||||
HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a");
|
||||
private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length);
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
if (message instanceof ByteBuf bb) {
|
||||
bb.readBytes(read);
|
||||
if (read.readableBytes() < HTTP2_PREFACE.length) {
|
||||
// Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more.
|
||||
return;
|
||||
}
|
||||
if (!Arrays.equals(read.array(), HTTP2_PREFACE)) {
|
||||
throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface");
|
||||
}
|
||||
context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder());
|
||||
context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder());
|
||||
context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler());
|
||||
context.fireChannelRead(bb);
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected message: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
ReferenceCountUtil.release(read);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) {
|
||||
|
||||
private static final byte FLAG_END_STREAM = 0x01;
|
||||
|
||||
enum FrameType {
|
||||
SETTINGS,
|
||||
HEADERS,
|
||||
DATA,
|
||||
WINDOW_UPDATE,
|
||||
OTHER;
|
||||
|
||||
static FrameType fromSerializedType(final byte type) {
|
||||
return switch (type) {
|
||||
case 0x00 -> Http2Frame.FrameType.DATA;
|
||||
case 0x01 -> Http2Frame.FrameType.HEADERS;
|
||||
case 0x04 -> Http2Frame.FrameType.SETTINGS;
|
||||
case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE;
|
||||
default -> Http2Frame.FrameType.OTHER;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit ByteBuf of entire HTTP/2 frame
|
||||
*/
|
||||
private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder {
|
||||
|
||||
public Http2LengthFieldFrameDecoder() {
|
||||
// Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload
|
||||
super(16 * 1024 * 1024, 0, 3, 6, 0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the serialized Http/2 frames into {@link Http2Frame} objects
|
||||
*/
|
||||
private static class Http2FrameDecoder extends ByteToMessageDecoder {
|
||||
|
||||
@Override
|
||||
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) throws Exception {
|
||||
// https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1
|
||||
final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3));
|
||||
final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4));
|
||||
out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream));
|
||||
}
|
||||
|
||||
boolean endStream(Http2Frame.FrameType frameType, byte flags) {
|
||||
// A gRPC request are packed into HTTP/2 frames like:
|
||||
// HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1)
|
||||
//
|
||||
// Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set
|
||||
// we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about
|
||||
// is having at least one complete request. In total, we can expect something like:
|
||||
// HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1)
|
||||
//
|
||||
// The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack
|
||||
// bit set, but HTTP/2 allows the client to send frames before getting the ACK.
|
||||
if (frameType == Http2Frame.FrameType.DATA) {
|
||||
return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM;
|
||||
}
|
||||
|
||||
// In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete
|
||||
// (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a
|
||||
// DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we
|
||||
// might as well pack that in.
|
||||
return frameType == Http2Frame.FrameType.WINDOW_UPDATE;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Collect HTTP/2 frames until we get an entire "request" to send
|
||||
*/
|
||||
private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
final List<Http2Frame> pendingFrames = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
if (message instanceof Http2Frame http2Frame) {
|
||||
if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) {
|
||||
throw new IllegalStateException(
|
||||
"HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type);
|
||||
}
|
||||
pendingFrames.add(http2Frame);
|
||||
if (http2Frame.endStream) {
|
||||
// We have a whole "request", emit the first request event and remove the http2 buffering handlers
|
||||
final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat(
|
||||
Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)),
|
||||
pendingFrames.stream().map(Http2Frame::bytes))
|
||||
.toArray(ByteBuf[]::new));
|
||||
pendingFrames.clear();
|
||||
context.pipeline().remove(Http2LengthFieldFrameDecoder.class);
|
||||
context.pipeline().remove(Http2FrameDecoder.class);
|
||||
context.pipeline().remove(this);
|
||||
context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request));
|
||||
}
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected message: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes()));
|
||||
pendingFrames.clear();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
@Override
|
||||
protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) {
|
||||
|
||||
return new NoiseAnonymousHandler(serverKeyPair);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Exception {
|
||||
return doHandshake(new byte[0]);
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
|
||||
clientHandshakeState.start();
|
||||
|
||||
// Send initiator handshake message
|
||||
|
||||
// 32 byte key, request payload, 16 byte AEAD tag
|
||||
final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16;
|
||||
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength];
|
||||
assertEquals(
|
||||
initiateHandshakeMessageLength,
|
||||
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
||||
|
||||
final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame(
|
||||
Unpooled.wrappedBuffer(initiateHandshakeMessage));
|
||||
|
||||
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess());
|
||||
assertEquals(0, initiateHandshakeFrame.refCnt());
|
||||
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// Read responder handshake message
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame)
|
||||
embeddedChannel.outboundMessages().poll();
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
|
||||
new byte[responderHandshakeFrame.content().readableBytes()];
|
||||
responderHandshakeFrame.content().readBytes(responderHandshakeBytes);
|
||||
|
||||
// ephemeral key, empty encrypted payload AEAD tag
|
||||
final byte[] handshakeResponsePayload = new byte[32 + 16];
|
||||
|
||||
assertEquals(0,
|
||||
clientHandshakeState.readMessage(
|
||||
responderHandshakeBytes, 0, responderHandshakeBytes.length,
|
||||
handshakeResponsePayload, 0));
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||
|
||||
return clientHandshakeState.split();
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
||||
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes()));
|
||||
final byte[] response = readNextPlaintext(cipherStatePair);
|
||||
assertArrayEquals(response, "pong".getBytes());
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
||||
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0]));
|
||||
assertNull(readNextPlaintext(cipherStatePair));
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,301 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
private final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
@Override
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) {
|
||||
return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Throwable {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
return doHandshake(identityPayload(accountIdentifier, deviceId));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final ByteBuffer bb = ByteBuffer.allocate(17 + 4);
|
||||
bb.put(identityPayload(accountIdentifier, deviceId));
|
||||
bb.put("ping".getBytes());
|
||||
|
||||
final byte[] response = readNextPlaintext(doHandshake(bb.array()));
|
||||
assertEquals(response.length, 4);
|
||||
assertEquals(new String(response), "pong");
|
||||
|
||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMissingIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
|
||||
|
||||
verifyNoInteractions(clientPublicKeysManager);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMalformedIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
// no deviceId byte
|
||||
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload));
|
||||
|
||||
verifyNoInteractions(clientPublicKeysManager);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeUnrecognizedDevice() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakePublicKeyMismatch() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
|
||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||
|
||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||
|
||||
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(
|
||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
|
||||
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
|
||||
|
||||
// While waiting for the public key, send another message
|
||||
final ChannelFuture f = embeddedChannel.writeOneInbound(
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await();
|
||||
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
|
||||
|
||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// shouldn't return any response or error, we've already processed an error
|
||||
embeddedChannel.checkException();
|
||||
assertNull(embeddedChannel.outboundMessages().poll());
|
||||
}
|
||||
|
||||
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
|
||||
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
|
||||
clientHandshakeState.start();
|
||||
return clientHandshakeState;
|
||||
}
|
||||
|
||||
private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload)
|
||||
throws ShortBufferException {
|
||||
// Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag
|
||||
final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16];
|
||||
int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length);
|
||||
assertEquals(written, initiatorMessageBytes.length);
|
||||
return initiatorMessageBytes;
|
||||
}
|
||||
|
||||
private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
|
||||
// 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload
|
||||
final int expectedResponsePayloadLength = message.length - 32 - 16;
|
||||
final byte[] responsePayload = new byte[expectedResponsePayloadLength];
|
||||
final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0);
|
||||
assertEquals(expectedResponsePayloadLength, responsePayloadLength);
|
||||
return responsePayload;
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
||||
|
||||
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(
|
||||
Unpooled.wrappedBuffer(initiatorMessage));
|
||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
|
||||
assertEquals(0, initiatorMessageFrame.refCnt());
|
||||
if (!await.isSuccess()) {
|
||||
throw await.cause();
|
||||
}
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// rethrow if running the task caused an error
|
||||
embeddedChannel.checkException();
|
||||
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
||||
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
||||
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
||||
|
||||
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||
|
||||
return clientHandshakeState.split();
|
||||
}
|
||||
|
||||
|
||||
private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
||||
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||
clientIdentityPayloadBuffer.put(deviceId);
|
||||
clientIdentityPayloadBuffer.flip();
|
||||
return clientIdentityPayloadBuffer.array();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* A netty user event that indicates that the noise handshake finished successfully.
|
||||
*
|
||||
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
|
||||
* the server included a payload in the handshake response.
|
||||
*/
|
||||
public record NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) {}
|
|
@ -0,0 +1,55 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import java.util.Optional;
|
||||
|
||||
class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final NoiseClientHandshakeHelper handshakeHelper;
|
||||
private final byte[] payload;
|
||||
|
||||
NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) {
|
||||
this.handshakeHelper = handshakeHelper;
|
||||
this.payload = payload;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||
byte[] handshakeMessage = handshakeHelper.write(payload);
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
}
|
||||
}
|
||||
super.userEventTriggered(context, event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||
throws NoiseHandshakeException {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content()));
|
||||
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
|
||||
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
} else {
|
||||
context.fireChannelRead(message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
handshakeHelper.destroy();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
public class NoiseClientHandshakeHelper {
|
||||
|
||||
private final HandshakePattern handshakePattern;
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) {
|
||||
this.handshakePattern = handshakePattern;
|
||||
this.handshakeState = handshakeState;
|
||||
}
|
||||
|
||||
static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
|
||||
try {
|
||||
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
|
||||
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
||||
state.start();
|
||||
return new NoiseClientHandshakeHelper(HandshakePattern.IK, state);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
|
||||
try {
|
||||
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
||||
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
||||
state.start();
|
||||
return new NoiseClientHandshakeHelper(HandshakePattern.NK, state);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
byte[] write(final byte[] requestPayload) throws ShortBufferException {
|
||||
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
|
||||
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
|
||||
return initiateHandshakeMessage;
|
||||
}
|
||||
|
||||
private int initiateHandshakeKeysLength() {
|
||||
return switch (handshakePattern) {
|
||||
// 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag
|
||||
case IK -> 32 + 32 + 16;
|
||||
// 32-byte ephemeral key
|
||||
case NK -> 32;
|
||||
};
|
||||
}
|
||||
|
||||
byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
|
||||
// Don't process additional messages if the handshake failed and we're just waiting to close
|
||||
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
|
||||
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
|
||||
}
|
||||
final int payloadLength = responderHandshakeMessage.length - 16 - 32;
|
||||
final byte[] responsePayload = new byte[payloadLength];
|
||||
final int payloadBytesRead;
|
||||
try {
|
||||
payloadBytesRead = handshakeState
|
||||
.readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0);
|
||||
if (payloadBytesRead != responsePayload.length) {
|
||||
throw new IllegalStateException(
|
||||
"Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead);
|
||||
}
|
||||
return responsePayload;
|
||||
} catch (ShortBufferException e) {
|
||||
throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage());
|
||||
} catch (BadPaddingException e) {
|
||||
throw new NoiseHandshakeException(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
CipherStatePair split() {
|
||||
return this.handshakeState.split();
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
this.handshakeState.destroy();
|
||||
}
|
||||
}
|
|
@ -11,30 +11,26 @@ import io.netty.channel.ChannelPromise;
|
|||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
||||
*/
|
||||
class NoiseTransportHandler extends ChannelDuplexHandler {
|
||||
class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||
|
||||
private final CipherStatePair cipherStatePair;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseTransportHandler.class);
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class);
|
||||
|
||||
NoiseTransportHandler(CipherStatePair cipherStatePair) {
|
||||
NoiseClientTransportHandler(CipherStatePair cipherStatePair) {
|
||||
this.cipherStatePair = cipherStatePair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
|
@ -45,22 +41,22 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
|
|||
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
|
||||
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
|
||||
} finally {
|
||||
frame.release();
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (message instanceof ByteBuf plaintext) {
|
||||
try {
|
||||
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
||||
final CipherState cipherState = cipherStatePair.getSender();
|
||||
final int plaintextLength = plaintext.readableBytes();
|
||||
|
||||
|
@ -75,7 +71,7 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
|
|||
|
||||
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
|
||||
} finally {
|
||||
plaintext.release();
|
||||
ReferenceCountUtil.release(plaintext);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof WebSocketFrame)) {
|
||||
|
@ -83,7 +79,6 @@ class NoiseTransportHandler extends ChannelDuplexHandler {
|
|||
// get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatNoException;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
|
||||
public class NoiseHandshakeHelperTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, new byte[0], new byte[0]);
|
||||
}
|
||||
|
||||
void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException {
|
||||
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair);
|
||||
NoiseClientHandshakeHelper clientHelper = switch (pattern) {
|
||||
case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair);
|
||||
case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey());
|
||||
};
|
||||
|
||||
final byte[] initiate = clientHelper.write(requestPayload);
|
||||
final ByteBuf actualRequestPayload = serverHelper.read(initiate);
|
||||
assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload);
|
||||
|
||||
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE);
|
||||
|
||||
final byte[] respond = serverHelper.write(responsePayload);
|
||||
byte[] actualResponsePayload = clientHelper.read(respond);
|
||||
assertThat(actualResponsePayload).isEqualTo(responsePayload);
|
||||
|
||||
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT);
|
||||
assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split());
|
||||
assertThatNoException().isThrownBy(() -> clientHelper.split());
|
||||
}
|
||||
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.util.Optional;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler {
|
||||
|
||||
private boolean receivedServerStaticKeyMessage = false;
|
||||
|
||||
NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) {
|
||||
super(rootPublicKey);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getNoiseProtocolName() {
|
||||
return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void startHandshake() {
|
||||
getHandshakeState().start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
// Don't process additional messages if we're just waiting to close because the handshake failed
|
||||
if (receivedServerStaticKeyMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
receivedServerStaticKeyMessage = true;
|
||||
handleServerStaticKeyMessage(context, frame);
|
||||
|
||||
context.pipeline().replace(this, null, new NoiseTransportHandler(getHandshakeState().split()));
|
||||
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
} else {
|
||||
context.fireChannelRead(message);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
|
||||
|
||||
@Override
|
||||
protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
|
||||
final byte[] serverPublicKeySignature) {
|
||||
|
||||
return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshake()
|
||||
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class));
|
||||
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.start();
|
||||
|
||||
{
|
||||
final byte[] ephemeralKeyMessageBytes = new byte[32];
|
||||
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
|
||||
|
||||
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
|
||||
|
||||
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
|
||||
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
|
||||
}
|
||||
|
||||
{
|
||||
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||
|
||||
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
||||
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
||||
|
||||
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
||||
|
||||
final byte[] serverPublicKeySignature = new byte[64];
|
||||
|
||||
final int payloadLength =
|
||||
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
|
||||
|
||||
assertEquals(serverPublicKeySignature.length, payloadLength);
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
|
||||
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
|
||||
}
|
||||
|
||||
assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class),
|
||||
"Handshake handler should remove self from pipeline after successful handshake");
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||
}
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import io.netty.util.internal.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import javax.crypto.AEADBadTagException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class NoiseTransportHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private CipherStatePair clientCipherStatePair;
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
// We use an NN handshake for this test just because it's a little shorter and easier to set up
|
||||
private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b";
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException {
|
||||
final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||
final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER);
|
||||
|
||||
clientHandshakeState.start();
|
||||
serverHandshakeState.start();
|
||||
|
||||
final byte[] clientEphemeralKeyMessage = new byte[32];
|
||||
assertEquals(clientEphemeralKeyMessage.length,
|
||||
clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0));
|
||||
|
||||
serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||
|
||||
// 32 bytes of key material plus a 16-byte MAC
|
||||
final byte[] serverEphemeralKeyMessage = new byte[48];
|
||||
assertEquals(serverEphemeralKeyMessage.length,
|
||||
serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0));
|
||||
|
||||
clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||
|
||||
clientCipherStatePair = clientHandshakeState.split();
|
||||
embeddedChannel = new EmbeddedChannel(new NoiseTransportHandler(serverHandshakeState.split()));
|
||||
|
||||
clientHandshakeState.destroy();
|
||||
serverHandshakeState.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelRead() throws ShortBufferException, InterruptedException {
|
||||
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
||||
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
|
||||
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
|
||||
final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll();
|
||||
assertNotNull(decryptedPlaintextBuffer);
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
|
||||
final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer);
|
||||
decryptedPlaintextBuffer.release();
|
||||
|
||||
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadBadCiphertext() throws InterruptedException {
|
||||
final byte[] bogusCiphertext = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
|
||||
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
||||
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
assertFalse(readCiphertextFuture.isSuccess());
|
||||
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadUnexpectedMessageType() throws InterruptedException {
|
||||
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
|
||||
|
||||
assertFalse(readFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void write() throws InterruptedException, ShortBufferException, BadPaddingException {
|
||||
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
|
||||
|
||||
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||
assertTrue(writePlaintextFuture.await().isSuccess());
|
||||
assertEquals(0, plaintextBuffer.refCnt());
|
||||
|
||||
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(ciphertextFrame);
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
||||
ciphertextFrame.release();
|
||||
|
||||
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
||||
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
|
||||
|
||||
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void writeUnexpectedMessageType() throws InterruptedException {
|
||||
final Object unexpectedMessaged = new Object();
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
|
||||
assertTrue(writeFuture.await().isSuccess());
|
||||
|
||||
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
}
|
||||
}
|
|
@ -1,22 +1,26 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import java.net.SocketAddress;
|
||||
import java.net.URI;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
class NoiseWebSocketTunnelClient implements AutoCloseable {
|
||||
|
||||
|
@ -26,19 +30,86 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|||
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
|
||||
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
|
||||
|
||||
public NoiseWebSocketTunnelClient(final SocketAddress remoteServerAddress,
|
||||
final URI websocketUri,
|
||||
final boolean authenticated,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final ECPublicKey rootPublicKey,
|
||||
@Nullable final UUID accountIdentifier,
|
||||
final byte deviceId,
|
||||
final HttpHeaders headers,
|
||||
final boolean useTls,
|
||||
@Nullable final X509Certificate trustedServerCertificate,
|
||||
@Nullable final Supplier<HAProxyMessage> proxyMessageSupplier,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final WebSocketCloseListener webSocketCloseListener) {
|
||||
static class Builder {
|
||||
|
||||
final SocketAddress remoteServerAddress;
|
||||
NioEventLoopGroup eventLoopGroup;
|
||||
ECPublicKey serverPublicKey;
|
||||
|
||||
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
||||
HttpHeaders headers = new DefaultHttpHeaders();
|
||||
WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER;
|
||||
|
||||
boolean authenticated = false;
|
||||
ECKeyPair ecKeyPair = null;
|
||||
UUID accountIdentifier = null;
|
||||
byte deviceId = 0x00;
|
||||
boolean useTls;
|
||||
X509Certificate trustedServerCertificate = null;
|
||||
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
||||
|
||||
Builder(
|
||||
final SocketAddress remoteServerAddress,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final ECPublicKey serverPublicKey) {
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.eventLoopGroup = eventLoopGroup;
|
||||
this.serverPublicKey = serverPublicKey;
|
||||
}
|
||||
|
||||
Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
||||
this.authenticated = true;
|
||||
this.accountIdentifier = accountIdentifier;
|
||||
this.deviceId = deviceId;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setWebsocketUri(final URI websocketUri) {
|
||||
this.websocketUri = websocketUri;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setUseTls(X509Certificate trustedServerCertificate) {
|
||||
this.useTls = true;
|
||||
this.trustedServerCertificate = trustedServerCertificate;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
|
||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setHeaders(final HttpHeaders headers) {
|
||||
this.headers = headers;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) {
|
||||
this.webSocketCloseListener = webSocketCloseListener;
|
||||
return this;
|
||||
}
|
||||
|
||||
Builder setServerPublicKey(ECPublicKey serverPublicKey) {
|
||||
this.serverPublicKey = serverPublicKey;
|
||||
return this;
|
||||
}
|
||||
|
||||
NoiseWebSocketTunnelClient build() {
|
||||
final NoiseWebSocketTunnelClient client =
|
||||
new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||
useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey,
|
||||
accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier,
|
||||
fastOpenRequest));
|
||||
client.start();
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup,
|
||||
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
||||
|
||||
this.serverBootstrap = new ServerBootstrap()
|
||||
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||
|
@ -47,28 +118,37 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|||
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||
@Override
|
||||
protected void initChannel(final LocalChannel localChannel) {
|
||||
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(useTls,
|
||||
trustedServerCertificate,
|
||||
websocketUri,
|
||||
authenticated,
|
||||
ecKeyPair,
|
||||
rootPublicKey,
|
||||
accountIdentifier,
|
||||
deviceId,
|
||||
headers,
|
||||
remoteServerAddress,
|
||||
webSocketCloseListener,
|
||||
proxyMessageSupplier));
|
||||
localChannel.pipeline()
|
||||
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
|
||||
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
|
||||
// in the handshake.
|
||||
.addLast(Http2Buffering.handler())
|
||||
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
|
||||
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
|
||||
// connect to the remote service
|
||||
.addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
|
||||
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
|
||||
requestBufferedEvent.fastOpenRequest().release();
|
||||
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
|
||||
}
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
})
|
||||
.addLast(new ClientErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
LocalAddress getLocalAddress() {
|
||||
return (LocalAddress) serverChannel.localAddress();
|
||||
}
|
||||
|
||||
NoiseWebSocketTunnelClient start() throws InterruptedException {
|
||||
serverChannel = serverBootstrap.bind().await().channel();
|
||||
private NoiseWebSocketTunnelClient start() {
|
||||
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -76,4 +156,5 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
|
|||
public void close() throws InterruptedException {
|
||||
serverChannel.close().await();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ import java.util.concurrent.Executors;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.Supplier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
|
@ -67,7 +68,6 @@ import org.signal.chat.rpc.GetRequestAttributesResponse;
|
|||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||
|
@ -89,7 +89,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
private ClientConnectionManager clientConnectionManager;
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private ECKeyPair rootKeyPair;
|
||||
private ECKeyPair serverKeyPair;
|
||||
private ECKeyPair clientKeyPair;
|
||||
|
||||
private ManagedLocalGrpcServer authenticatedGrpcServer;
|
||||
|
@ -153,9 +153,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
|
||||
}
|
||||
|
||||
rootKeyPair = Curve.generateKeyPair();
|
||||
clientKeyPair = Curve.generateKeyPair();
|
||||
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||
serverKeyPair = Curve.generateKeyPair();
|
||||
|
||||
clientConnectionManager = new ClientConnectionManager();
|
||||
|
||||
|
@ -192,14 +191,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
anonymousGrpcServer.start();
|
||||
|
||||
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
||||
new X509Certificate[] { serverTlsCertificate },
|
||||
new X509Certificate[]{serverTlsCertificate},
|
||||
serverTlsPrivateKey,
|
||||
nioEventLoopGroup,
|
||||
delegatedTaskExecutor,
|
||||
clientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
RECOGNIZED_PROXY_SECRET);
|
||||
|
@ -214,7 +212,6 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
clientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
RECOGNIZED_PROXY_SECRET);
|
||||
|
@ -241,9 +238,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = { true, false })
|
||||
@ValueSource(booleans = {true, false})
|
||||
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) {
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -259,24 +258,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = { true, false })
|
||||
@ValueSource(booleans = {true, false})
|
||||
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
|
||||
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||
true,
|
||||
clientKeyPair,
|
||||
rootKeyPair.getPublicKey(),
|
||||
ACCOUNT_IDENTIFIER,
|
||||
DEVICE_ID,
|
||||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
|
||||
nioEventLoopGroup,
|
||||
WebSocketCloseListener.NOOP_LISTENER)
|
||||
.start()) {
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient
|
||||
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
|
||||
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -295,10 +283,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
|
||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||
|
||||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final NoiseWebSocketTunnelClient client =
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) {
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
@ -312,7 +301,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -322,7 +312,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -335,7 +327,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -345,8 +338,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client =
|
||||
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
@ -360,29 +354,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedToAnonymousService() throws InterruptedException {
|
||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
|
||||
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
URI.create("wss://localhost/anonymous"),
|
||||
true,
|
||||
clientKeyPair,
|
||||
rootKeyPair.getPublicKey(),
|
||||
ACCOUNT_IDENTIFIER,
|
||||
DEVICE_ID,
|
||||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start()) {
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI)
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -395,12 +378,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAnonymous() throws InterruptedException {
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAnonymousClient()) {
|
||||
try (final NoiseWebSocketTunnelClient client = anonymous().build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -420,9 +404,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final NoiseWebSocketTunnelClient client =
|
||||
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = anonymous()
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -435,29 +420,18 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAnonymousToAuthenticatedService() throws InterruptedException {
|
||||
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
|
||||
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
URI.create("wss://localhost/authenticated"),
|
||||
false,
|
||||
null,
|
||||
rootKeyPair.getPublicKey(),
|
||||
null,
|
||||
(byte) 0,
|
||||
new DefaultHttpHeaders(),
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start()) {
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = anonymous()
|
||||
.setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
|
@ -470,7 +444,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
verify(webSocketCloseListener).handleWebSocketClosedByServer(
|
||||
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||
}
|
||||
|
||||
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
|
||||
|
@ -506,8 +481,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
|
||||
.uri(authenticatedUri)
|
||||
.PUT(HttpRequest.BodyPublishers.ofString("test"))
|
||||
.build(),
|
||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||
.build(),
|
||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||
"Non-GET requests should not be allowed");
|
||||
|
||||
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
|
||||
|
@ -538,8 +513,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
.add("Accept-Language", acceptLanguage)
|
||||
.add("User-Agent", userAgent);
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client =
|
||||
buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) {
|
||||
try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
@ -582,7 +556,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
};
|
||||
|
||||
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||
.setWebSocketCloseListener(webSocketCloseListener)
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
|
@ -606,63 +582,24 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
|||
}
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
|
||||
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
|
||||
private NoiseWebSocketTunnelClient.Builder anonymous() {
|
||||
return new NoiseWebSocketTunnelClient
|
||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||
.setUseTls(serverTlsCertificate);
|
||||
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
|
||||
throws InterruptedException {
|
||||
|
||||
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false);
|
||||
private NoiseWebSocketTunnelClient.Builder authenticated() {
|
||||
return new NoiseWebSocketTunnelClient
|
||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
|
||||
.setUseTls(serverTlsCertificate);
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
|
||||
final ECPublicKey rootPublicKey,
|
||||
final HttpHeaders headers,
|
||||
final boolean includeProxyMessage) throws InterruptedException {
|
||||
|
||||
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||
true,
|
||||
clientKeyPair,
|
||||
rootPublicKey,
|
||||
ACCOUNT_IDENTIFIER,
|
||||
DEVICE_ID,
|
||||
headers,
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start();
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
|
||||
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
|
||||
}
|
||||
|
||||
private NoiseWebSocketTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
|
||||
final ECPublicKey rootPublicKey,
|
||||
final HttpHeaders headers) throws InterruptedException {
|
||||
|
||||
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
|
||||
NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI,
|
||||
false,
|
||||
null,
|
||||
rootPublicKey,
|
||||
null,
|
||||
(byte) 0,
|
||||
headers,
|
||||
true,
|
||||
serverTlsCertificate,
|
||||
null,
|
||||
nioEventLoopGroup,
|
||||
webSocketCloseListener)
|
||||
.start();
|
||||
}
|
||||
|
||||
private static HAProxyMessage buildProxyMessage() {
|
||||
return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443);
|
||||
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
|
||||
return includeProxyMesage
|
||||
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443)
|
||||
: null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
||||
class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler {
|
||||
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
private final UUID accountIdentifier;
|
||||
private final byte deviceId;
|
||||
|
||||
private boolean receivedServerStaticKeyMessage = false;
|
||||
|
||||
NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair,
|
||||
final ECPublicKey rootPublicKey,
|
||||
final UUID accountIdentifier,
|
||||
final byte deviceId) {
|
||||
|
||||
super(rootPublicKey);
|
||||
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
|
||||
this.accountIdentifier = accountIdentifier;
|
||||
this.deviceId = deviceId;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getNoiseProtocolName() {
|
||||
return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void startHandshake() {
|
||||
final HandshakeState handshakeState = getHandshakeState();
|
||||
|
||||
// Noise-java derives the public key from the private key, so we just need to set the private key
|
||||
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
|
||||
handshakeState.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||
throws NoiseHandshakeException, ShortBufferException {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
try {
|
||||
// Don't process additional messages if the handshake failed and we're just waiting to close
|
||||
if (receivedServerStaticKeyMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
receivedServerStaticKeyMessage = true;
|
||||
handleServerStaticKeyMessage(context, frame);
|
||||
|
||||
final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17);
|
||||
clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||
clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||
clientIdentityBuffer.put(deviceId);
|
||||
clientIdentityBuffer.flip();
|
||||
|
||||
final HandshakeState handshakeState = getHandshakeState();
|
||||
|
||||
// We're sending two 32-byte keys plus the client identity payload
|
||||
final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()];
|
||||
handshakeState.writeMessage(
|
||||
staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining());
|
||||
|
||||
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage)))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
context.pipeline().replace(this, null, new NoiseTransportHandler(handshakeState.split()));
|
||||
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
} else {
|
||||
context.fireChannelRead(message);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,453 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherState;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
|
||||
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
@Override
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
|
||||
final byte[] serverPublicKeySignature) {
|
||||
|
||||
return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshake()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should remove self from pipeline after successful handshake");
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMissingIdentityInformation()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
|
||||
{
|
||||
final byte[] clientStaticKeyMessageBytes = new byte[64];
|
||||
final int messageLength =
|
||||
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0);
|
||||
|
||||
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||
|
||||
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||
|
||||
final ChannelFuture writeClientStaticKeyMessageFuture =
|
||||
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
|
||||
|
||||
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
|
||||
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||
}
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMalformedIdentityInformation()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
|
||||
{
|
||||
final byte[] clientStaticKeyMessageBytes = new byte[96];
|
||||
final int messageLength =
|
||||
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32);
|
||||
|
||||
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||
|
||||
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||
|
||||
final ChannelFuture writeClientStaticKeyMessageFuture =
|
||||
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
|
||||
|
||||
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
|
||||
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||
}
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeUnrecognizedDevice()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakePublicKeyMismatch()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeBufferedReads()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||
|
||||
final ByteBuf[] additionalMessages = new ByteBuf[4];
|
||||
final CipherState senderState = clientHandshakeState.split().getSender();
|
||||
|
||||
try {
|
||||
for (int i = 0; i < additionalMessages.length; i++) {
|
||||
final byte[] contentBytes = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
// Copy the "plaintext" portion of the content bytes for future assertions
|
||||
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
|
||||
|
||||
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
|
||||
// tag
|
||||
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
|
||||
|
||||
assertTrue(
|
||||
embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await()
|
||||
.isSuccess());
|
||||
}
|
||||
|
||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should remove self from pipeline after successful handshake");
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||
|
||||
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||
assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(),
|
||||
"Buffered message should pass through pipeline after successful handshake");
|
||||
}
|
||||
} finally {
|
||||
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||
additionalMessage.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeFailureBufferedReads()
|
||||
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||
|
||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||
|
||||
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||
|
||||
final ByteBuf[] additionalMessages = new ByteBuf[4];
|
||||
final CipherState senderState = clientHandshakeState.split().getSender();
|
||||
|
||||
try {
|
||||
for (int i = 0; i < additionalMessages.length; i++) {
|
||||
final byte[] contentBytes = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
// Copy the "plaintext" portion of the content bytes for future assertions
|
||||
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
|
||||
|
||||
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
|
||||
// tag
|
||||
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
|
||||
|
||||
assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess());
|
||||
}
|
||||
|
||||
findPublicKeyFuture.complete(Optional.empty());
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty(),
|
||||
"Buffered messages should not pass through pipeline after failed handshake");
|
||||
} finally {
|
||||
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||
additionalMessage.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair)
|
||||
throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
|
||||
clientHandshakeState.start();
|
||||
|
||||
{
|
||||
final byte[] ephemeralKeyMessageBytes = new byte[32];
|
||||
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
|
||||
|
||||
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
|
||||
|
||||
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
|
||||
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
|
||||
}
|
||||
|
||||
{
|
||||
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||
|
||||
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
||||
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||
|
||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
||||
|
||||
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
||||
|
||||
final byte[] serverPublicKeySignature = new byte[64];
|
||||
|
||||
final int payloadLength =
|
||||
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
|
||||
|
||||
assertEquals(serverPublicKeySignature.length, payloadLength);
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
|
||||
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
|
||||
}
|
||||
|
||||
return clientHandshakeState;
|
||||
}
|
||||
|
||||
private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId)
|
||||
throws ShortBufferException, InterruptedException {
|
||||
|
||||
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||
clientIdentityPayloadBuffer.put(deviceId);
|
||||
clientIdentityPayloadBuffer.flip();
|
||||
|
||||
final byte[] clientStaticKeyMessageBytes = new byte[81];
|
||||
final int messageLength =
|
||||
handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining());
|
||||
|
||||
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||
|
||||
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||
|
||||
assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess());
|
||||
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline
|
||||
* after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound
|
||||
* messages to be bytes.
|
||||
*/
|
||||
abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class);
|
||||
|
||||
/**
|
||||
* Handle an inbound message. The frame will be automatically released after the method is finished running.
|
||||
*
|
||||
* @param context The current {@link ChannelHandlerContext}
|
||||
* @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise
|
||||
* packet
|
||||
* @throws Exception
|
||||
*/
|
||||
abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception;
|
||||
|
||||
/**
|
||||
* Handle an outbound byte message. The message will be automatically released after the method is finished running.
|
||||
*
|
||||
* @param context The current {@link ChannelHandlerContext}
|
||||
* @param bytes The bytes to write
|
||||
* @throws Exception
|
||||
*/
|
||||
abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes,
|
||||
final ChannelPromise promise) throws Exception;
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof BinaryWebSocketFrame frame) {
|
||||
handleInbound(context, frame.content());
|
||||
} else {
|
||||
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (message instanceof ByteBuf serverResponse) {
|
||||
try {
|
||||
handleOutbound(context, serverResponse, promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(serverResponse);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof WebSocketFrame)) {
|
||||
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
||||
// get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -79,7 +79,6 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
|||
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
||||
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
|
||||
Curve.generateKeyPair(),
|
||||
new byte[64],
|
||||
RECOGNIZED_PROXY_SECRET),
|
||||
userEventRecordingHandler);
|
||||
|
||||
|
@ -88,7 +87,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
|
||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends ChannelHandler> expectedHandlerClass) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||
|
||||
|
@ -102,8 +101,8 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
|||
|
||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||
return List.of(
|
||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
|
||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
|
||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class),
|
||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -130,5 +130,7 @@ linkDevice.secret: AAAAAAAAAAA=
|
|||
|
||||
tlsKeyStore.password: unset
|
||||
|
||||
noiseTunnel.noiseStaticPrivateKey: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
# The below private key was generated exclusively for testing purposes. Do not use it in any other context.
|
||||
# Corresponding public key: cYUAFtkWK/4x3AfW/yw7qgIo/mQUaRSWaPolGQkiL14=
|
||||
noiseTunnel.noiseStaticPrivateKey: qK5FD9WmuhoLPsS/Z4swcZkwDn9OpeM5ZmcEVMpEQ24=
|
||||
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA
|
||||
|
|
|
@ -452,7 +452,6 @@ callingTurnManualTable:
|
|||
noiseTunnel:
|
||||
port: 8443
|
||||
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
|
||||
noiseRootPublicKeySignature: ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz
|
||||
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
|
||||
|
||||
externalRequestFilter:
|
||||
|
|
Loading…
Reference in New Issue