Update noise-gRPC protocol errors
This commit is contained in:
parent
b8d5b2c8ea
commit
0cc5431867
|
@ -10,8 +10,12 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
|
||||||
public class DeviceIdUtil {
|
public class DeviceIdUtil {
|
||||||
|
|
||||||
|
public static boolean isValid(int deviceId) {
|
||||||
|
return deviceId >= Device.PRIMARY_ID && deviceId <= Byte.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
static byte validate(int deviceId) {
|
static byte validate(int deviceId) {
|
||||||
if (deviceId < Device.PRIMARY_ID || deviceId > Byte.MAX_VALUE) {
|
if (!isValid(deviceId)) {
|
||||||
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
|
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
|
||||||
*/
|
|
||||||
public class ClientAuthenticationException extends NoStackTraceException {
|
|
||||||
}
|
|
|
@ -1,10 +1,9 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import javax.crypto.BadPaddingException;
|
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
@ -16,9 +15,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||||
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||||
|
|
||||||
private static OutboundCloseErrorMessage UNAUTHENTICATED_CLOSE = new OutboundCloseErrorMessage(
|
|
||||||
OutboundCloseErrorMessage.Code.AUTHENTICATION_ERROR,
|
|
||||||
"Not authenticated");
|
|
||||||
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
||||||
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
||||||
"Noise encryption error");
|
"Noise encryption error");
|
||||||
|
@ -29,7 +25,6 @@ public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||||
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
||||||
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
||||||
e.getMessage());
|
e.getMessage());
|
||||||
case ClientAuthenticationException ignored -> UNAUTHENTICATED_CLOSE;
|
|
||||||
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||||
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||||
default -> {
|
default -> {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import io.netty.channel.ChannelInitializer;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.net.InetAddress;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -48,7 +49,9 @@ public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAd
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||||
if (event instanceof NoiseIdentityDeterminedEvent(final Optional<AuthenticatedDevice> authenticatedDevice)) {
|
if (event instanceof NoiseIdentityDeterminedEvent(
|
||||||
|
final Optional<AuthenticatedDevice> authenticatedDevice,
|
||||||
|
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
|
||||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||||
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||||
|
@ -57,6 +60,9 @@ public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAd
|
||||||
? authenticatedGrpcServerAddress
|
? authenticatedGrpcServerAddress
|
||||||
: anonymousGrpcServerAddress;
|
: anonymousGrpcServerAddress;
|
||||||
|
|
||||||
|
GrpcClientConnectionManager.handleHandshakeInitiated(
|
||||||
|
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
|
||||||
|
|
||||||
new Bootstrap()
|
new Bootstrap()
|
||||||
.remoteAddress(grpcServerAddress)
|
.remoteAddress(grpcServerAddress)
|
||||||
.channel(LocalChannel.class)
|
.channel(LocalChannel.class)
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A NoiseAnonymousHandler is a netty pipeline element that handles the responder side of an unauthenticated handshake
|
|
||||||
* and noise encryption/decryption.
|
|
||||||
* <p>
|
|
||||||
* A noise NK handshake must be used for unauthenticated connections. Optionally, the initiator can also include an
|
|
||||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
|
||||||
* initial message delay (fast open).
|
|
||||||
* <p>
|
|
||||||
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
|
|
||||||
* indicating that initiator connected anonymously.
|
|
||||||
*/
|
|
||||||
public class NoiseAnonymousHandler extends NoiseHandler {
|
|
||||||
|
|
||||||
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
|
|
||||||
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
CompletableFuture<HandshakeResult> handleHandshakePayload(final ChannelHandlerContext context,
|
|
||||||
final Optional<byte[]> initiatorPublicKey, final ByteBuf handshakePayload) {
|
|
||||||
return CompletableFuture.completedFuture(new HandshakeResult(
|
|
||||||
handshakePayload,
|
|
||||||
Optional.empty()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,96 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
|
||||||
import java.security.MessageDigest;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A NoiseAuthenticatedHandler is a netty pipeline element that handles the responder side of an authenticated handshake
|
|
||||||
* and noise encryption/decryption. Authenticated handshakes are noise IK handshakes where the initiator's static public
|
|
||||||
* key is authenticated by the responder.
|
|
||||||
* <p>
|
|
||||||
* The authenticated handshake requires the initiator to provide a payload with their first handshake message that
|
|
||||||
* includes their account identifier and device id in network byte-order. Optionally, the initiator can also include an
|
|
||||||
* initial request in their payload. If provided, this allows the server to begin processing the request without an
|
|
||||||
* initial message delay (fast open).
|
|
||||||
* <pre>
|
|
||||||
* +-----------------+----------------+------------------------+
|
|
||||||
* | UUID (16) | deviceId (1) | request bytes (N) |
|
|
||||||
* +-----------------+----------------+------------------------+
|
|
||||||
* </pre>
|
|
||||||
* <p>
|
|
||||||
* For a successful handshake, the static key provided in the handshake message must match the server's stored public
|
|
||||||
* key for the device identified by the provided ACI and deviceId.
|
|
||||||
* <p>
|
|
||||||
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
|
|
||||||
*/
|
|
||||||
public class NoiseAuthenticatedHandler extends NoiseHandler {
|
|
||||||
|
|
||||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
|
||||||
|
|
||||||
public NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
|
||||||
final ECKeyPair ecKeyPair) {
|
|
||||||
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
|
|
||||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
CompletableFuture<HandshakeResult> handleHandshakePayload(
|
|
||||||
final ChannelHandlerContext context,
|
|
||||||
final Optional<byte[]> initiatorPublicKey,
|
|
||||||
final ByteBuf handshakePayload) throws NoiseHandshakeException {
|
|
||||||
if (handshakePayload.readableBytes() < 17) {
|
|
||||||
throw new NoiseHandshakeException("Invalid handshake payload");
|
|
||||||
}
|
|
||||||
|
|
||||||
final byte[] publicKeyFromClient = initiatorPublicKey
|
|
||||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
|
||||||
|
|
||||||
// Advances the read index by 16 bytes
|
|
||||||
final UUID accountIdentifier = parseUUID(handshakePayload);
|
|
||||||
|
|
||||||
// Advances the read index by 1 byte
|
|
||||||
final byte deviceId = handshakePayload.readByte();
|
|
||||||
|
|
||||||
final ByteBuf fastOpenRequest = handshakePayload.slice();
|
|
||||||
return clientPublicKeysManager
|
|
||||||
.findPublicKey(accountIdentifier, deviceId)
|
|
||||||
.handleAsync((storedPublicKey, throwable) -> {
|
|
||||||
if (throwable != null) {
|
|
||||||
ReferenceCountUtil.release(fastOpenRequest);
|
|
||||||
throw ExceptionUtils.wrap(throwable);
|
|
||||||
}
|
|
||||||
final boolean valid = storedPublicKey
|
|
||||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
|
||||||
.orElse(false);
|
|
||||||
if (!valid) {
|
|
||||||
throw ExceptionUtils.wrap(new ClientAuthenticationException());
|
|
||||||
}
|
|
||||||
return new HandshakeResult(
|
|
||||||
fastOpenRequest,
|
|
||||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
|
||||||
}, context.executor());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a {@link UUID} out of bytes, advancing the readerIdx by 16
|
|
||||||
*
|
|
||||||
* @param bytes The {@link ByteBuf} to read from
|
|
||||||
* @return The parsed UUID
|
|
||||||
* @throws NoiseHandshakeException If a UUID could not be parsed from bytes
|
|
||||||
*/
|
|
||||||
private UUID parseUUID(final ByteBuf bytes) throws NoiseHandshakeException {
|
|
||||||
if (bytes.readableBytes() < 16) {
|
|
||||||
throw new NoiseHandshakeException("Could not parse account identifier");
|
|
||||||
}
|
|
||||||
return new UUID(bytes.readLong(), bytes.readLong());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -11,159 +11,59 @@ import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import io.netty.util.concurrent.PromiseCombiner;
|
import io.netty.util.concurrent.PromiseCombiner;
|
||||||
import io.netty.util.internal.EmptyArrays;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
|
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
|
||||||
* inbound messages, and encrypts outbound messages
|
* messages
|
||||||
*/
|
*/
|
||||||
public abstract class NoiseHandler extends ChannelDuplexHandler {
|
public class NoiseHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||||
|
private final CipherStatePair cipherStatePair;
|
||||||
|
|
||||||
private enum State {
|
NoiseHandler(CipherStatePair cipherStatePair) {
|
||||||
// Waiting for handshake to complete
|
this.cipherStatePair = cipherStatePair;
|
||||||
HANDSHAKE,
|
|
||||||
// Can freely exchange encrypted noise messages on an established session
|
|
||||||
TRANSPORT,
|
|
||||||
// Finished with error
|
|
||||||
ERROR
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private final NoiseHandshakeHelper handshakeHelper;
|
|
||||||
|
|
||||||
private State state = State.HANDSHAKE;
|
|
||||||
private CipherStatePair cipherStatePair;
|
|
||||||
|
|
||||||
NoiseHandler(NoiseHandshakeHelper handshakeHelper) {
|
|
||||||
this.handshakeHelper = handshakeHelper;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The result of processing an initiator handshake payload
|
|
||||||
*
|
|
||||||
* @param fastOpenRequest A fast-open request included in the handshake. If none was present, this should be an
|
|
||||||
* empty ByteBuf
|
|
||||||
* @param authenticatedDevice If present, the successfully authenticated initiator identity
|
|
||||||
*/
|
|
||||||
record HandshakeResult(ByteBuf fastOpenRequest, Optional<AuthenticatedDevice> authenticatedDevice) {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse and potentially authenticate the initiator handshake message
|
|
||||||
*
|
|
||||||
* @param context A {@link ChannelHandlerContext}
|
|
||||||
* @param initiatorPublicKey The initiator's static public key, if a handshake pattern that includes it was used
|
|
||||||
* @param handshakePayload The handshake payload provided in the initiator message
|
|
||||||
* @return A {@link HandshakeResult} that includes an authenticated device and a parsed fast-open request if one was
|
|
||||||
* present in the handshake payload.
|
|
||||||
* @throws NoiseHandshakeException If the handshake payload was invalid
|
|
||||||
* @throws ClientAuthenticationException If the initiatorPublicKey could not be authenticated
|
|
||||||
*/
|
|
||||||
abstract CompletableFuture<HandshakeResult> handleHandshakePayload(
|
|
||||||
final ChannelHandlerContext context,
|
|
||||||
final Optional<byte[]> initiatorPublicKey,
|
|
||||||
final ByteBuf handshakePayload) throws NoiseHandshakeException, ClientAuthenticationException;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
try {
|
try {
|
||||||
if (message instanceof ByteBuf frame) {
|
if (message instanceof ByteBuf frame) {
|
||||||
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||||
final String error = "Invalid noise message length " + frame.readableBytes();
|
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
|
||||||
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
|
|
||||||
}
|
}
|
||||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||||
// We'll need to copy it to a heap buffer.
|
// We'll need to copy it to a heap buffer.
|
||||||
handleInboundMessage(context, ByteBufUtil.getBytes(frame));
|
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
|
||||||
} else {
|
} else {
|
||||||
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
||||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
|
||||||
fail(context, e);
|
|
||||||
} finally {
|
} finally {
|
||||||
ReferenceCountUtil.release(message);
|
ReferenceCountUtil.release(message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleInboundMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
|
||||||
throws NoiseHandshakeException, ShortBufferException, BadPaddingException, ClientAuthenticationException {
|
|
||||||
switch (state) {
|
|
||||||
|
|
||||||
// Got an initiator handshake message
|
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||||
case HANDSHAKE -> {
|
throws ShortBufferException, BadPaddingException {
|
||||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||||
handleHandshakePayload(context, handshakeHelper.remotePublicKey(), payload).whenCompleteAsync(
|
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||||
(result, throwable) -> {
|
final int plaintextLength = cipherState.decryptWithAd(null,
|
||||||
if (state == State.ERROR) {
|
frameBytes, 0,
|
||||||
return;
|
frameBytes, 0,
|
||||||
}
|
frameBytes.length);
|
||||||
if (throwable != null) {
|
|
||||||
fail(context, ExceptionUtils.unwrap(throwable));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(result.authenticatedDevice()));
|
|
||||||
|
|
||||||
// Now that we've authenticated, write the handshake response
|
// Forward the decrypted plaintext along
|
||||||
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
|
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
||||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
|
||||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
|
||||||
|
|
||||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
|
||||||
this.state = State.TRANSPORT;
|
|
||||||
this.cipherStatePair = handshakeHelper.getHandshakeState().split();
|
|
||||||
if (result.fastOpenRequest().isReadable()) {
|
|
||||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
|
||||||
// encrypt the response when the server writes back through us
|
|
||||||
context.fireChannelRead(result.fastOpenRequest());
|
|
||||||
} else {
|
|
||||||
ReferenceCountUtil.release(result.fastOpenRequest());
|
|
||||||
}
|
|
||||||
}, context.executor());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Got a client message that should be decrypted and forwarded
|
|
||||||
case TRANSPORT -> {
|
|
||||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
|
||||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
|
||||||
final int plaintextLength = cipherState.decryptWithAd(null,
|
|
||||||
frameBytes, 0,
|
|
||||||
frameBytes, 0,
|
|
||||||
frameBytes.length);
|
|
||||||
|
|
||||||
// Forward the decrypted plaintext along
|
|
||||||
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
|
||||||
}
|
|
||||||
|
|
||||||
// The session is already in an error state, drop the message
|
|
||||||
case ERROR -> {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the state to the error state (so subsequent messages fast-fail) and propagate the failure reason on the
|
|
||||||
* context
|
|
||||||
*/
|
|
||||||
private void fail(final ChannelHandlerContext context, final Throwable cause) {
|
|
||||||
this.state = State.ERROR;
|
|
||||||
context.fireExceptionCaught(cause);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -208,4 +108,12 @@ public abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
context.write(message, promise);
|
context.write(message, promise);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(ChannelHandlerContext var1) {
|
||||||
|
if (cipherStatePair != null) {
|
||||||
|
cipherStatePair.destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufInputStream;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.InetAddress;
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
|
||||||
|
* encrypt/decrypt subsequent data frames
|
||||||
|
* <p>
|
||||||
|
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
|
||||||
|
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
|
||||||
|
* currently supports two types of handshakes.
|
||||||
|
* <p>
|
||||||
|
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
|
||||||
|
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
|
||||||
|
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
|
||||||
|
* deviceId.
|
||||||
|
* <p>
|
||||||
|
* The second are NK handshakes which are anonymous.
|
||||||
|
* <p>
|
||||||
|
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
|
||||||
|
* to begin processing the request without an initial message delay (fast open).
|
||||||
|
* <p>
|
||||||
|
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
|
||||||
|
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
|
||||||
|
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
|
||||||
|
* frames.
|
||||||
|
*/
|
||||||
|
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||||
|
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
|
||||||
|
.build().toByteArray();
|
||||||
|
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||||
|
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||||
|
.build().toByteArray();
|
||||||
|
|
||||||
|
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
|
||||||
|
private boolean receivedHandshakeInit = false;
|
||||||
|
|
||||||
|
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
private final ECKeyPair ecKeyPair;
|
||||||
|
|
||||||
|
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
||||||
|
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
try {
|
||||||
|
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
|
||||||
|
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
|
||||||
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
|
}
|
||||||
|
if (receivedHandshakeInit) {
|
||||||
|
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
|
||||||
|
}
|
||||||
|
receivedHandshakeInit = true;
|
||||||
|
|
||||||
|
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||||
|
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||||
|
// We'll need to copy it to a heap buffer
|
||||||
|
handleInboundHandshake(context,
|
||||||
|
handshakeInit.getRemoteAddress(),
|
||||||
|
handshakeInit.getHandshakePattern(),
|
||||||
|
ByteBufUtil.getBytes(handshakeInit.content()));
|
||||||
|
} finally {
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleInboundHandshake(
|
||||||
|
final ChannelHandlerContext context,
|
||||||
|
final InetAddress remoteAddress,
|
||||||
|
final HandshakePattern handshakePattern,
|
||||||
|
final byte[] frameBytes) throws NoiseHandshakeException {
|
||||||
|
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
|
||||||
|
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||||
|
|
||||||
|
// Parse the handshake message
|
||||||
|
final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||||
|
try {
|
||||||
|
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new NoiseHandshakeException("Failed to parse handshake message");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (handshakePattern) {
|
||||||
|
case NK -> {
|
||||||
|
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
|
||||||
|
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
|
||||||
|
}
|
||||||
|
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
|
||||||
|
}
|
||||||
|
case IK -> {
|
||||||
|
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
|
||||||
|
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||||
|
final UUID accountIdentifier = aci(handshakeInit);
|
||||||
|
final byte deviceId = deviceId(handshakeInit);
|
||||||
|
clientPublicKeysManager
|
||||||
|
.findPublicKey(accountIdentifier, deviceId)
|
||||||
|
.whenCompleteAsync((storedPublicKey, throwable) -> {
|
||||||
|
if (throwable != null) {
|
||||||
|
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
final boolean valid = storedPublicKey
|
||||||
|
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||||
|
.orElse(false);
|
||||||
|
if (!valid) {
|
||||||
|
// Write a handshake response indicating that the client used the wrong public key
|
||||||
|
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
|
||||||
|
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
|
||||||
|
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleAuthenticated(context,
|
||||||
|
handshakeHelper, remoteAddress, handshakeInit,
|
||||||
|
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||||
|
}, context.executor());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleAuthenticated(final ChannelHandlerContext context,
|
||||||
|
final NoiseHandshakeHelper handshakeHelper,
|
||||||
|
final InetAddress remoteAddress,
|
||||||
|
final NoiseTunnelProtos.HandshakeInit handshakeInit,
|
||||||
|
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||||
|
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
|
||||||
|
maybeAuthenticatedDevice,
|
||||||
|
remoteAddress,
|
||||||
|
handshakeInit.getUserAgent(),
|
||||||
|
handshakeInit.getAcceptLanguage()));
|
||||||
|
|
||||||
|
// Now that we've authenticated, write the handshake response
|
||||||
|
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
|
||||||
|
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||||
|
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||||
|
|
||||||
|
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||||
|
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
|
||||||
|
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
|
||||||
|
// decrypted the fastOpen request
|
||||||
|
context.pipeline()
|
||||||
|
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
|
||||||
|
context.pipeline().remove(NoiseHandshakeHandler.class);
|
||||||
|
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
|
||||||
|
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||||
|
// encrypt the response when the server writes back through us
|
||||||
|
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||||
|
try {
|
||||||
|
return UUIDUtil.fromByteString(handshakePayload.getAci());
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
throw new NoiseHandshakeException("Could not parse aci");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||||
|
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
|
||||||
|
throw new NoiseHandshakeException("Invalid deviceId");
|
||||||
|
}
|
||||||
|
return (byte) handshakePayload.getDeviceId();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.DefaultByteBufHolder;
|
||||||
|
import java.net.InetAddress;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
|
||||||
|
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
|
||||||
|
*/
|
||||||
|
public class NoiseHandshakeInit extends DefaultByteBufHolder {
|
||||||
|
|
||||||
|
private final InetAddress remoteAddress;
|
||||||
|
private final HandshakePattern handshakePattern;
|
||||||
|
|
||||||
|
public NoiseHandshakeInit(
|
||||||
|
final InetAddress remoteAddress,
|
||||||
|
final HandshakePattern handshakePattern,
|
||||||
|
final ByteBuf initiatorHandshakeMessage) {
|
||||||
|
super(initiatorHandshakeMessage);
|
||||||
|
this.remoteAddress = remoteAddress;
|
||||||
|
this.handshakePattern = handshakePattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
public InetAddress getRemoteAddress() {
|
||||||
|
return remoteAddress;
|
||||||
|
}
|
||||||
|
|
||||||
|
public HandshakePattern getHandshakePattern() {
|
||||||
|
return handshakePattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import java.net.InetAddress;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
|
||||||
|
@ -9,5 +10,12 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
*
|
*
|
||||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||||
* type that performs authentication
|
* type that performs authentication
|
||||||
|
* @param remoteAddress the remote address of the connecting client
|
||||||
|
* @param userAgent the client supplied userAgent
|
||||||
|
* @param acceptLanguage the client supplied acceptLanguage
|
||||||
*/
|
*/
|
||||||
public record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
|
public record NoiseIdentityDeterminedEvent(
|
||||||
|
Optional<AuthenticatedDevice> authenticatedDevice,
|
||||||
|
InetAddress remoteAddress,
|
||||||
|
String userAgent,
|
||||||
|
String acceptLanguage) {}
|
||||||
|
|
|
@ -9,6 +9,7 @@ package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
*/
|
*/
|
||||||
public record OutboundCloseErrorMessage(Code code, String message) {
|
public record OutboundCloseErrorMessage(Code code, String message) {
|
||||||
public enum Code {
|
public enum Code {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The server decided to close the connection. This could be because the server is going away, or it could be
|
* The server decided to close the connection. This could be because the server is going away, or it could be
|
||||||
* because the credentials for the connected client have been updated.
|
* because the credentials for the connected client have been updated.
|
||||||
|
@ -25,11 +26,6 @@ public record OutboundCloseErrorMessage(Code code, String message) {
|
||||||
*/
|
*/
|
||||||
NOISE_HANDSHAKE_ERROR,
|
NOISE_HANDSHAKE_ERROR,
|
||||||
|
|
||||||
/**
|
|
||||||
* The provided credentials were not valid
|
|
||||||
*/
|
|
||||||
AUTHENTICATION_ERROR,
|
|
||||||
|
|
||||||
INTERNAL_SERVER_ERROR
|
INTERNAL_SERVER_ERROR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,12 +30,12 @@ public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||||
|
|
||||||
public enum FrameType {
|
public enum FrameType {
|
||||||
/**
|
/**
|
||||||
* The payload is the initiator message or the responder message for a Noise NK handshake. If established, the
|
* The payload is the initiator message for a Noise NK handshake. If established, the
|
||||||
* session will be unauthenticated.
|
* session will be unauthenticated.
|
||||||
*/
|
*/
|
||||||
NK_HANDSHAKE((byte) 1),
|
NK_HANDSHAKE((byte) 1),
|
||||||
/**
|
/**
|
||||||
* The payload is the initiator message or the responder message for a Noise IK handshake. If established, the
|
* The payload is the initiator message for a Noise IK handshake. If established, the
|
||||||
* session will be authenticated.
|
* session will be authenticated.
|
||||||
*/
|
*/
|
||||||
IK_HANDSHAKE((byte) 2),
|
IK_HANDSHAKE((byte) 2),
|
||||||
|
@ -44,9 +44,10 @@ public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||||
*/
|
*/
|
||||||
DATA((byte) 3),
|
DATA((byte) 3),
|
||||||
/**
|
/**
|
||||||
* A framing layer error occurred. The payload carries error details.
|
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
|
||||||
|
* closed.
|
||||||
*/
|
*/
|
||||||
ERROR((byte) 4);
|
CLOSE((byte) 4);
|
||||||
|
|
||||||
private final byte frameType;
|
private final byte frameType;
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||||
public boolean isHandshake() {
|
public boolean isHandshake() {
|
||||||
return switch (this) {
|
return switch (this) {
|
||||||
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
||||||
case DATA, ERROR -> false;
|
case DATA, CLOSE -> false;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
|
||||||
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
||||||
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
||||||
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
||||||
case 4 -> NoiseDirectFrame.FrameType.ERROR;
|
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
|
||||||
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -4,66 +4,44 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Waits for a Handshake {@link NoiseDirectFrame} and then installs a {@link NoiseDirectDataFrameCodec} and
|
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
|
||||||
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} and removes itself
|
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
|
||||||
*/
|
*/
|
||||||
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
|
||||||
private final ECKeyPair ecKeyPair;
|
|
||||||
|
|
||||||
public NoiseDirectHandshakeSelector(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
|
||||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
|
||||||
this.ecKeyPair = ecKeyPair;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||||
if (msg instanceof NoiseDirectFrame frame) {
|
if (msg instanceof NoiseDirectFrame frame) {
|
||||||
try {
|
try {
|
||||||
// We've received an inbound handshake frame so we know what kind of NoiseHandler we need (authenticated or
|
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
|
||||||
// anonymous). We construct it here, and then remember the handshake type so we can annotate our handshake
|
|
||||||
// response with the correct frame type whenever we receive it.
|
|
||||||
final ChannelDuplexHandler noiseHandler = switch (frame.frameType()) {
|
|
||||||
case DATA, ERROR ->
|
|
||||||
throw new NoiseHandshakeException("Invalid frame type for first message " + frame.frameType());
|
|
||||||
case IK_HANDSHAKE -> new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
|
||||||
case NK_HANDSHAKE -> new NoiseAnonymousHandler(ecKeyPair);
|
|
||||||
};
|
|
||||||
if (ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
|
|
||||||
// TODO: Provide connection metadata / headers in handshake payload
|
|
||||||
GrpcClientConnectionManager.handleHandshakeInitiated(ctx.channel(),
|
|
||||||
inetSocketAddress.getAddress(),
|
|
||||||
"NoiseDirect",
|
|
||||||
"");
|
|
||||||
|
|
||||||
} else {
|
|
||||||
throw new IOException("Could not determine remote address");
|
throw new IOException("Could not determine remote address");
|
||||||
}
|
}
|
||||||
|
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
|
||||||
|
// needs into a NoiseHandshakeInit message and forward that along
|
||||||
|
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
|
||||||
|
switch (frame.frameType()) {
|
||||||
|
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
|
||||||
|
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
|
||||||
|
case IK_HANDSHAKE -> HandshakePattern.IK;
|
||||||
|
case NK_HANDSHAKE -> HandshakePattern.NK;
|
||||||
|
}, frame.content());
|
||||||
|
|
||||||
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
||||||
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
||||||
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
||||||
// because we may receive an outbound message from the noiseHandler
|
// because we may receive an outbound message from the noiseHandler
|
||||||
ctx.pipeline().addAfter(ctx.name(), null, noiseHandler);
|
|
||||||
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
||||||
ctx.fireChannelRead(frame.content());
|
ctx.fireChannelRead(handshakeMessage);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
ReferenceCountUtil.release(msg);
|
ReferenceCountUtil.release(msg);
|
||||||
throw e;
|
throw e;
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||||
|
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Watches for inbound close frames and closes the connection in response
|
||||||
|
*/
|
||||||
|
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(ChannelInboundHandlerAdapter.class, "clientClose");
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||||
|
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||||
|
try {
|
||||||
|
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
|
||||||
|
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||||
|
|
||||||
|
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
|
||||||
|
} finally {
|
||||||
|
ReferenceCountUtil.release(msg);
|
||||||
|
ctx.close();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.fireChannelRead(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,20 +17,19 @@ class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
|
||||||
@Override
|
@Override
|
||||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||||
final NoiseDirectProtos.Error.Type type = switch (err.code()) {
|
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
|
||||||
case SERVER_CLOSED -> NoiseDirectProtos.Error.Type.UNAVAILABLE;
|
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
|
||||||
case NOISE_ERROR -> NoiseDirectProtos.Error.Type.ENCRYPTION_ERROR;
|
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
|
||||||
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.Error.Type.HANDSHAKE_ERROR;
|
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
|
||||||
case AUTHENTICATION_ERROR -> NoiseDirectProtos.Error.Type.AUTHENTICATION_ERROR;
|
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
|
||||||
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.Error.Type.INTERNAL_ERROR;
|
|
||||||
};
|
};
|
||||||
final NoiseDirectProtos.Error proto = NoiseDirectProtos.Error.newBuilder()
|
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
|
||||||
.setType(type)
|
.setCode(code)
|
||||||
.setMessage(err.message())
|
.setMessage(err.message())
|
||||||
.build();
|
.build();
|
||||||
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
||||||
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
||||||
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.ERROR, byteBuf))
|
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
|
||||||
.addListener(ChannelFutureListener.CLOSE);
|
.addListener(ChannelFutureListener.CLOSE);
|
||||||
} else {
|
} else {
|
||||||
ctx.write(msg, promise);
|
ctx.write(msg, promise);
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
@ -50,18 +51,20 @@ public class NoiseDirectTunnelServer implements Managed {
|
||||||
protected void initChannel(SocketChannel socketChannel) {
|
protected void initChannel(SocketChannel socketChannel) {
|
||||||
socketChannel.pipeline()
|
socketChannel.pipeline()
|
||||||
.addLast(new ProxyProtocolDetectionHandler())
|
.addLast(new ProxyProtocolDetectionHandler())
|
||||||
.addLast(new HAProxyMessageHandler());
|
.addLast(new HAProxyMessageHandler())
|
||||||
|
|
||||||
socketChannel.pipeline()
|
|
||||||
// frame byte followed by a 2-byte length field
|
// frame byte followed by a 2-byte length field
|
||||||
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
||||||
// Parses NoiseDirectFrames from wire bytes and vice versa
|
// Parses NoiseDirectFrames from wire bytes and vice versa
|
||||||
.addLast(new NoiseDirectFrameCodec())
|
.addLast(new NoiseDirectFrameCodec())
|
||||||
|
// Terminate the connection if the client sends us a close frame
|
||||||
|
.addLast(new NoiseDirectInboundCloseHandler())
|
||||||
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
||||||
.addLast(new NoiseDirectOutboundErrorHandler())
|
.addLast(new NoiseDirectOutboundErrorHandler())
|
||||||
// Waits for the handshake to finish and then replaces itself with a NoiseDirectFrameCodec and a
|
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
|
||||||
// NoiseHandler to handle noise encryption/decryption
|
// NoiseDirectDataFrameCodec to handle subsequent data frames
|
||||||
.addLast(new NoiseDirectHandshakeSelector(clientPublicKeysManager, ecKeyPair))
|
.addLast(new NoiseDirectHandshakeSelector())
|
||||||
|
// Performs the noise handshake and then replace itself with a NoiseHandler
|
||||||
|
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||||
// once the Noise handshake has completed
|
// once the Noise handshake has completed
|
||||||
.addLast(new EstablishLocalGrpcConnectionHandler(
|
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||||
|
|
|
@ -4,8 +4,7 @@ import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
|
||||||
enum ApplicationWebSocketCloseReason {
|
enum ApplicationWebSocketCloseReason {
|
||||||
NOISE_HANDSHAKE_ERROR(4001),
|
NOISE_HANDSHAKE_ERROR(4001),
|
||||||
CLIENT_AUTHENTICATION_ERROR(4002),
|
NOISE_ENCRYPTION_ERROR(4002);
|
||||||
NOISE_ENCRYPTION_ERROR(4003);
|
|
||||||
|
|
||||||
private final int statusCode;
|
private final int statusCode;
|
||||||
|
|
||||||
|
|
|
@ -108,9 +108,12 @@ public class NoiseWebSocketTunnelServer implements Managed {
|
||||||
.addLast(new WebSocketOutboundErrorHandler())
|
.addLast(new WebSocketOutboundErrorHandler())
|
||||||
.addLast(new RejectUnsupportedMessagesHandler())
|
.addLast(new RejectUnsupportedMessagesHandler())
|
||||||
.addLast(new WebsocketPayloadCodec())
|
.addLast(new WebsocketPayloadCodec())
|
||||||
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
// The WebSocket handshake complete listener will forward the first payload supplemented with
|
||||||
// a WebSocket handshake has been completed
|
// data from the websocket handshake completion event, and then remove itself from the pipeline
|
||||||
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))
|
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
|
||||||
|
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
|
||||||
|
// NoiseHandler
|
||||||
|
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||||
// once the Noise handshake has completed
|
// once the Noise handshake has completed
|
||||||
.addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
.addLast(new EstablishLocalGrpcConnectionHandler(grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||||
|
|
|
@ -7,14 +7,9 @@ import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
import javax.crypto.BadPaddingException;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ClientAuthenticationException;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
||||||
|
@ -46,7 +41,6 @@ class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
|
||||||
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
||||||
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
||||||
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
||||||
case AUTHENTICATION_ERROR -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode();
|
|
||||||
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
||||||
};
|
};
|
||||||
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
||||||
|
|
|
@ -2,14 +2,14 @@ package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.net.InetAddresses;
|
import com.google.common.net.InetAddresses;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandler;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
@ -17,13 +17,10 @@ import java.security.MessageDigest;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||||
|
@ -31,10 +28,6 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
*/
|
*/
|
||||||
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
|
||||||
|
|
||||||
private final ECKeyPair ecKeyPair;
|
|
||||||
|
|
||||||
private final byte[] recognizedProxySecret;
|
private final byte[] recognizedProxySecret;
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
||||||
|
@ -45,12 +38,10 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
||||||
|
|
||||||
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
private InetAddress remoteAddress = null;
|
||||||
final ECKeyPair ecKeyPair,
|
private HandshakePattern handshakePattern = null;
|
||||||
final String recognizedProxySecret) {
|
|
||||||
|
|
||||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
|
||||||
this.ecKeyPair = ecKeyPair;
|
|
||||||
|
|
||||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||||
|
@ -61,53 +52,58 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||||
@Override
|
@Override
|
||||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||||
final InetAddress preferredRemoteAddress;
|
final Optional<InetAddress> maybePreferredRemoteAddress =
|
||||||
{
|
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
||||||
final Optional<InetAddress> maybePreferredRemoteAddress =
|
|
||||||
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
|
||||||
|
|
||||||
if (maybePreferredRemoteAddress.isEmpty()) {
|
if (maybePreferredRemoteAddress.isEmpty()) {
|
||||||
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
|
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
|
||||||
"Could not determine remote address"))
|
"Could not determine remote address"))
|
||||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
preferredRemoteAddress = maybePreferredRemoteAddress.get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GrpcClientConnectionManager.handleHandshakeInitiated(context.channel(),
|
remoteAddress = maybePreferredRemoteAddress.get();
|
||||||
preferredRemoteAddress,
|
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
|
||||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
|
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
|
||||||
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
|
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
|
||||||
|
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||||
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
// internal error if something slipped through.
|
||||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||||
new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
|
|
||||||
|
|
||||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
|
||||||
new NoiseAnonymousHandler(ecKeyPair);
|
|
||||||
|
|
||||||
default -> {
|
|
||||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
|
||||||
// internal error if something slipped through.
|
|
||||||
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
context.fireUserEventTriggered(event);
|
context.fireUserEventTriggered(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object msg) {
|
||||||
|
try {
|
||||||
|
if (!(msg instanceof ByteBuf frame)) {
|
||||||
|
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (handshakePattern == null || remoteAddress == null) {
|
||||||
|
throw new IllegalStateException("Received payload before websocket handshake complete");
|
||||||
|
}
|
||||||
|
|
||||||
|
final NoiseHandshakeInit handshakeMessage =
|
||||||
|
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
|
||||||
|
|
||||||
|
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
|
||||||
|
context.fireChannelRead(handshakeMessage);
|
||||||
|
} catch (Exception e) {
|
||||||
|
ReferenceCountUtil.release(msg);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
||||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||||
|
|
||||||
final byte[] recognizedProxySecretFromHeader =
|
final byte[] recognizedProxySecretFromHeader =
|
||||||
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
|
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
|
||||||
.getBytes(StandardCharsets.UTF_8);
|
.getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
|
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
|
||||||
|
|
||||||
|
|
|
@ -8,15 +8,46 @@ syntax = "proto3";
|
||||||
option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect";
|
option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect";
|
||||||
option java_outer_classname = "NoiseDirectProtos";
|
option java_outer_classname = "NoiseDirectProtos";
|
||||||
|
|
||||||
message Error {
|
message CloseReason {
|
||||||
enum Type {
|
enum Code {
|
||||||
UNSPECIFIED = 0;
|
UNSPECIFIED = 0;
|
||||||
HANDSHAKE_ERROR = 1;
|
// Indicates non-error termination
|
||||||
ENCRYPTION_ERROR = 2;
|
// Examples:
|
||||||
UNAVAILABLE = 3;
|
// - The client is finished with the connection
|
||||||
INTERNAL_ERROR = 4;
|
OK = 1;
|
||||||
AUTHENTICATION_ERROR = 5;
|
|
||||||
|
// There was an issue with the handshake. If sent after a handshake response,
|
||||||
|
// the response includes more information about the nature of the error
|
||||||
|
// Examples:
|
||||||
|
// - The client did not provide a handshake message
|
||||||
|
// - The client had incorrect authentication credentials. The handshake
|
||||||
|
// payload includes additional details
|
||||||
|
HANDSHAKE_ERROR = 2;
|
||||||
|
|
||||||
|
// There was an encryption/decryption issue after the handshake
|
||||||
|
// Examples:
|
||||||
|
// - The client incorrectly encrypted a noise message and it had a bad
|
||||||
|
// AEAD tag
|
||||||
|
ENCRYPTION_ERROR = 3;
|
||||||
|
|
||||||
|
// The server is temporarily unavailable, going away, or requires a
|
||||||
|
// connection reset
|
||||||
|
// Examples:
|
||||||
|
// - The server is shutting down
|
||||||
|
// - The client’s authentication credentials have been rotated
|
||||||
|
UNAVAILABLE = 4;
|
||||||
|
|
||||||
|
// There was an an internal error
|
||||||
|
// Examples:
|
||||||
|
// - The server experienced a temporary database outage that prevented it
|
||||||
|
// from checking the client's credentials
|
||||||
|
INTERNAL_ERROR = 5;
|
||||||
}
|
}
|
||||||
Type type = 1;
|
|
||||||
|
Code code = 1;
|
||||||
|
|
||||||
|
// If present, includes details about the error. Implementations should never
|
||||||
|
// parse or otherwise implement conditional logic based on the contents of the
|
||||||
|
// error message string, it is for logging and debugging purposes only.
|
||||||
string message = 2;
|
string message = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
option java_package = "org.whispersystems.textsecuregcm.grpc.net";
|
||||||
|
option java_outer_classname = "NoiseTunnelProtos";
|
||||||
|
|
||||||
|
message HandshakeInit {
|
||||||
|
string user_agent = 1;
|
||||||
|
|
||||||
|
// An Accept-Language as described in
|
||||||
|
// https://httpwg.org/specs/rfc9110.html#field.accept-language
|
||||||
|
string accept_language = 2;
|
||||||
|
|
||||||
|
// A UUID serialized as 16 bytes (big end first). Must be unset (empty) for an
|
||||||
|
// unauthenticated handshake
|
||||||
|
bytes aci = 3;
|
||||||
|
|
||||||
|
// The deviceId, 0 < deviceId < 128. Must be unset for an unauthenticated
|
||||||
|
// handshake
|
||||||
|
uint32 device_id = 4;
|
||||||
|
|
||||||
|
// The first bytes of the application request byte stream, may contain less
|
||||||
|
// than a full request
|
||||||
|
bytes fast_open_request = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message HandshakeResponse {
|
||||||
|
enum Code {
|
||||||
|
UNSPECIFIED = 0;
|
||||||
|
|
||||||
|
// The noise session may be used to send application layer requests
|
||||||
|
OK = 1;
|
||||||
|
|
||||||
|
// The provided client static key did not match the registered public key
|
||||||
|
// for the provided aci/deviceId.
|
||||||
|
WRONG_PUBLIC_KEY = 2;
|
||||||
|
|
||||||
|
// The client version is to old, it should be upgraded before retrying
|
||||||
|
DEPRECATED = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The handshake outcome
|
||||||
|
Code code = 1;
|
||||||
|
|
||||||
|
// Additional information about an error status, for debugging only
|
||||||
|
string error_details = 2;
|
||||||
|
|
||||||
|
// An optional response to a fast_open_request provided in the HandshakeInit.
|
||||||
|
// Note that a response may not be present even if a fast_open_request was
|
||||||
|
// present. If so, the response will be returned in a later message.
|
||||||
|
bytes fast_open_response = 3;
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
@ -16,12 +17,13 @@ import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandler;
|
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.net.InetAddress;
|
||||||
|
import java.net.UnknownHostException;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
@ -36,16 +38,29 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
|
|
||||||
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
protected ECKeyPair serverKeyPair;
|
protected ECKeyPair serverKeyPair;
|
||||||
|
protected ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||||
|
|
||||||
private EmbeddedChannel embeddedChannel;
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
static final String USER_AGENT = "Test/User-Agent";
|
||||||
|
static final String ACCEPT_LANGUAGE = "test-lang";
|
||||||
|
static final InetAddress REMOTE_ADDRESS;
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3});
|
||||||
|
} catch (UnknownHostException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static class PongHandler extends ChannelInboundHandlerAdapter {
|
private static class PongHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -93,7 +108,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
void setUp() {
|
void setUp() {
|
||||||
serverKeyPair = Curve.generateKeyPair();
|
serverKeyPair = Curve.generateKeyPair();
|
||||||
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||||
embeddedChannel = new EmbeddedChannel(getHandler(serverKeyPair), noiseHandshakeCompleteHandler);
|
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||||
|
embeddedChannel = new EmbeddedChannel(
|
||||||
|
new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair),
|
||||||
|
noiseHandshakeCompleteHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
|
@ -110,8 +128,6 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract ChannelHandler getHandler(final ECKeyPair serverKeyPair);
|
|
||||||
|
|
||||||
protected abstract CipherStatePair doHandshake() throws Throwable;
|
protected abstract CipherStatePair doHandshake() throws Throwable;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -140,7 +156,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||||
|
|
||||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await();
|
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await();
|
||||||
|
|
||||||
assertFalse(writeFuture.isSuccess());
|
assertFalse(writeFuture.isSuccess());
|
||||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||||
|
@ -292,4 +308,19 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
|
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
|
||||||
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void channelAttributes() throws Throwable {
|
||||||
|
doHandshake();
|
||||||
|
final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent();
|
||||||
|
assertEquals(REMOTE_ADDRESS, event.remoteAddress());
|
||||||
|
assertEquals(USER_AGENT, event.userAgent());
|
||||||
|
assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() {
|
||||||
|
return NoiseTunnelProtos.HandshakeInit.newBuilder()
|
||||||
|
.setUserAgent(USER_AGENT)
|
||||||
|
.setAcceptLanguage(ACCEPT_LANGUAGE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -248,7 +248,10 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL
|
||||||
} finally {
|
} finally {
|
||||||
channel.shutdown();
|
channel.shutdown();
|
||||||
}
|
}
|
||||||
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
assertEquals(
|
||||||
|
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||||
|
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,12 +272,35 @@ public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractL
|
||||||
} finally {
|
} finally {
|
||||||
channel.shutdown();
|
channel.shutdown();
|
||||||
}
|
}
|
||||||
|
assertEquals(
|
||||||
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
|
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||||
|
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||||
|
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void clientNormalClosure() throws InterruptedException {
|
||||||
|
final NoiseTunnelClient client = anonymous().build();
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||||
|
assertEquals(0, response.getDeviceId());
|
||||||
|
client.close();
|
||||||
|
|
||||||
|
// When we gracefully close the tunnel client, we should send an OK close frame
|
||||||
|
final CloseFrameEvent closeFrame = client.closeFrameFuture().join();
|
||||||
|
assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator());
|
||||||
|
assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void connectAnonymous() throws InterruptedException {
|
void connectAnonymous() throws InterruptedException {
|
||||||
try (final NoiseTunnelClient client = anonymous().build()) {
|
try (final NoiseTunnelClient client = anonymous().build()) {
|
||||||
|
|
|
@ -8,28 +8,23 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.HandshakeState;
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|
||||||
|
|
||||||
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
protected NoiseAnonymousHandler getHandler(final ECKeyPair serverKeyPair) {
|
|
||||||
|
|
||||||
return new NoiseAnonymousHandler(serverKeyPair);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected CipherStatePair doHandshake() throws Exception {
|
protected CipherStatePair doHandshake() throws Exception {
|
||||||
return doHandshake(new byte[0]);
|
return doHandshake(baseHandshakeInit().build().toByteArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
|
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
|
||||||
|
@ -49,26 +44,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
assertEquals(
|
assertEquals(
|
||||||
initiateHandshakeMessageLength,
|
initiateHandshakeMessageLength,
|
||||||
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
||||||
final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage);
|
final NoiseHandshakeInit message = new NoiseHandshakeInit(
|
||||||
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess());
|
REMOTE_ADDRESS,
|
||||||
assertEquals(0, initiateHandshakeMessageBuf.refCnt());
|
HandshakePattern.NK,
|
||||||
|
Unpooled.wrappedBuffer(initiateHandshakeMessage));
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess());
|
||||||
|
assertEquals(0, message.refCnt());
|
||||||
|
|
||||||
embeddedChannel.runPendingTasks();
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
// Read responder handshake message
|
// Read responder handshake message
|
||||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||||
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
|
assertNotNull(responderHandshakeFrame);
|
||||||
new byte[responderHandshakeFrame.readableBytes()];
|
final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame);
|
||||||
responderHandshakeFrame.readBytes(responderHandshakeBytes);
|
|
||||||
|
|
||||||
// ephemeral key, empty encrypted payload AEAD tag
|
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||||
final byte[] handshakeResponsePayload = new byte[32 + 16];
|
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||||
|
.build();
|
||||||
|
|
||||||
assertEquals(0,
|
// ephemeral key, payload, AEAD tag
|
||||||
|
assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length);
|
||||||
|
|
||||||
|
final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()];
|
||||||
|
assertEquals(expectedHandshakeResponse.getSerializedSize(),
|
||||||
clientHandshakeState.readMessage(
|
clientHandshakeState.readMessage(
|
||||||
responderHandshakeBytes, 0, responderHandshakeBytes.length,
|
responderHandshakeBytes, 0, responderHandshakeBytes.length,
|
||||||
handshakeResponsePayload, 0));
|
handshakeResponsePlaintext, 0));
|
||||||
|
|
||||||
|
assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext));
|
||||||
|
|
||||||
final byte[] serverPublicKey = new byte[32];
|
final byte[] serverPublicKey = new byte[32];
|
||||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||||
|
@ -78,27 +82,35 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleCompleteHandshakeWithRequest() throws ShortBufferException, BadPaddingException {
|
void handleCompleteHandshakeWithRequest() throws Exception {
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake("ping".getBytes()));
|
final byte[] handshakePlaintext = baseHandshakeInit()
|
||||||
|
.setFastOpenRequest(ByteString.copyFromUtf8("ping")).build()
|
||||||
|
.toByteArray();
|
||||||
|
|
||||||
|
final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext);
|
||||||
final byte[] response = readNextPlaintext(cipherStatePair);
|
final byte[] response = readNextPlaintext(cipherStatePair);
|
||||||
assertArrayEquals(response, "pong".getBytes());
|
assertArrayEquals(response, "pong".getBytes());
|
||||||
|
|
||||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
assertEquals(
|
||||||
|
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||||
|
getNoiseHandshakeCompleteEvent());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
|
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAnonymousHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake(new byte[0]));
|
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake());
|
||||||
assertNull(readNextPlaintext(cipherStatePair));
|
assertNull(readNextPlaintext(cipherStatePair));
|
||||||
|
|
||||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
assertEquals(
|
||||||
|
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||||
|
getNoiseHandshakeCompleteEvent());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,18 +8,20 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
import com.southernstorm.noise.protocol.HandshakeState;
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.util.internal.EmptyArrays;
|
import io.netty.util.internal.EmptyArrays;
|
||||||
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.security.NoSuchAlgorithmException;
|
import java.security.NoSuchAlgorithmException;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -28,40 +30,24 @@ import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
|
|
||||||
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
private ClientPublicKeysManager clientPublicKeysManager;
|
|
||||||
private final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
private final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
@Override
|
|
||||||
@BeforeEach
|
|
||||||
void setUp() {
|
|
||||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
|
||||||
|
|
||||||
super.setUp();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected NoiseAuthenticatedHandler getHandler(final ECKeyPair serverKeyPair) {
|
|
||||||
return new NoiseAuthenticatedHandler(clientPublicKeysManager, serverKeyPair);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected CipherStatePair doHandshake() throws Throwable {
|
protected CipherStatePair doHandshake() throws Throwable {
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1);
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
return doHandshake(identityPayload(accountIdentifier, deviceId));
|
return doHandshake(identityPayload(accountIdentifier, deviceId));
|
||||||
|
@ -71,7 +57,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
|
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
@ -81,7 +67,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
|
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
|
||||||
|
|
||||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
assertEquals(
|
||||||
|
new NoiseIdentityDeterminedEvent(
|
||||||
|
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||||
|
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||||
getNoiseHandshakeCompleteEvent());
|
getNoiseHandshakeCompleteEvent());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +78,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
|
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
@ -97,15 +86,19 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
final ByteBuffer bb = ByteBuffer.allocate(17 + 4);
|
final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||||
bb.put(identityPayload(accountIdentifier, deviceId));
|
.setFastOpenRequest(ByteString.copyFromUtf8("ping"))
|
||||||
bb.put("ping".getBytes());
|
.build()
|
||||||
|
.toByteArray();
|
||||||
|
|
||||||
final byte[] response = readNextPlaintext(doHandshake(bb.array()));
|
final byte[] response = readNextPlaintext(doHandshake(handshakeInit));
|
||||||
assertEquals(response.length, 4);
|
assertEquals(4, response.length);
|
||||||
assertEquals(new String(response), "pong");
|
assertEquals("pong", new String(response));
|
||||||
|
|
||||||
assertEquals(new NoiseIdentityDeterminedEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
assertEquals(
|
||||||
|
new NoiseIdentityDeterminedEvent(
|
||||||
|
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||||
|
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||||
getNoiseHandshakeCompleteEvent());
|
getNoiseHandshakeCompleteEvent());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +106,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
void handleCompleteHandshakeMissingIdentityInformation() {
|
void handleCompleteHandshakeMissingIdentityInformation() {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
|
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
|
||||||
|
|
||||||
|
@ -121,7 +114,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
assertNull(getNoiseHandshakeCompleteEvent());
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||||
|
@ -132,7 +125,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
void handleCompleteHandshakeMalformedIdentityInformation() {
|
void handleCompleteHandshakeMalformedIdentityInformation() {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
// no deviceId byte
|
// no deviceId byte
|
||||||
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
|
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
|
||||||
|
@ -142,7 +135,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
|
|
||||||
assertNull(getNoiseHandshakeCompleteEvent());
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||||
|
@ -150,10 +143,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleCompleteHandshakeUnrecognizedDevice() {
|
void handleCompleteHandshakeUnrecognizedDevice() throws Throwable {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
@ -161,11 +154,13 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
doHandshake(
|
||||||
|
identityPayload(accountIdentifier, deviceId),
|
||||||
|
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||||
|
|
||||||
assertNull(getNoiseHandshakeCompleteEvent());
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||||
|
@ -173,10 +168,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleCompleteHandshakePublicKeyMismatch() {
|
void handleCompleteHandshakePublicKeyMismatch() throws Throwable {
|
||||||
|
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
@ -184,18 +179,21 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||||
|
|
||||||
assertThrows(ClientAuthenticationException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
doHandshake(
|
||||||
|
identityPayload(accountIdentifier, deviceId),
|
||||||
|
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||||
|
|
||||||
assertNull(getNoiseHandshakeCompleteEvent());
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class),
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void handleInvalidExtraWrites() throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
void handleInvalidExtraWrites()
|
||||||
|
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
assertNotNull(embeddedChannel.pipeline().get(NoiseAuthenticatedHandler.class));
|
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
@ -205,25 +203,23 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||||
|
|
||||||
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(
|
final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit(
|
||||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)));
|
REMOTE_ADDRESS,
|
||||||
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
|
HandshakePattern.IK,
|
||||||
|
Unpooled.wrappedBuffer(
|
||||||
|
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess());
|
||||||
|
|
||||||
// While waiting for the public key, send another message
|
// While waiting for the public key, send another message
|
||||||
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
|
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
|
||||||
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
|
assertInstanceOf(IllegalArgumentException.class, f.exceptionNow());
|
||||||
|
|
||||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||||
embeddedChannel.runPendingTasks();
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
// shouldn't return any response or error, we've already processed an error
|
|
||||||
embeddedChannel.checkException();
|
|
||||||
assertNull(embeddedChannel.outboundMessages().poll());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void handleOversizeHandshakeMessage() {
|
public void handleOversizeHandshakeMessage() {
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
|
||||||
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
||||||
ByteBuffer.wrap(big)
|
ByteBuffer.wrap(big)
|
||||||
.put(UUIDUtil.toBytes(UUID.randomUUID()))
|
.put(UUIDUtil.toBytes(UUID.randomUUID()))
|
||||||
|
@ -231,6 +227,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
|
assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void handleKeyLookupError() throws Throwable {
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.failedFuture(new IOException()));
|
||||||
|
assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||||
|
}
|
||||||
|
|
||||||
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
|
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
|
||||||
final HandshakeState clientHandshakeState =
|
final HandshakeState clientHandshakeState =
|
||||||
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||||
|
@ -262,15 +267,22 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
|
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
|
||||||
|
return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable {
|
||||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
|
||||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||||
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
||||||
|
|
||||||
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage);
|
final NoiseHandshakeInit initMessage = new NoiseHandshakeInit(
|
||||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
|
REMOTE_ADDRESS,
|
||||||
assertEquals(0, initiatorMessageFrame.refCnt());
|
HandshakePattern.IK,
|
||||||
if (!await.isSuccess()) {
|
Unpooled.wrappedBuffer(initiatorMessage));
|
||||||
|
final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await();
|
||||||
|
assertEquals(0, initMessage.refCnt());
|
||||||
|
if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||||
throw await.cause();
|
throw await.cause();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,17 +292,27 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
// and issue a "handshake complete" event.
|
// and issue a "handshake complete" event.
|
||||||
embeddedChannel.runPendingTasks();
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
// rethrow if running the task caused an error
|
// rethrow if running the task caused an error, and the caller isn't expecting an error
|
||||||
embeddedChannel.checkException();
|
if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||||
|
embeddedChannel.checkException();
|
||||||
|
}
|
||||||
|
|
||||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||||
|
|
||||||
final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||||
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
assertNotNull(handshakeResponseFrame);
|
||||||
new byte[serverStaticKeyMessageFrame.readableBytes()];
|
final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame);
|
||||||
serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes);
|
|
||||||
|
|
||||||
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);
|
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||||
|
.setCode(expectedStatus)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final byte[] actualHandshakeResponsePlaintext =
|
||||||
|
readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes);
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
expectedHandshakeResponsePlaintext,
|
||||||
|
NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext));
|
||||||
|
|
||||||
final byte[] serverPublicKey = new byte[32];
|
final byte[] serverPublicKey = new byte[32];
|
||||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||||
|
@ -299,13 +321,15 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||||
return clientHandshakeState.split();
|
return clientHandshakeState.split();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) {
|
||||||
|
return baseHandshakeInit()
|
||||||
|
.setAci(UUIDUtil.toByteString(accountIdentifier))
|
||||||
|
.setDeviceId(deviceId);
|
||||||
|
}
|
||||||
|
|
||||||
private static byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
||||||
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
|
return identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
.build()
|
||||||
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
.toByteArray();
|
||||||
clientIdentityPayloadBuffer.put(deviceId);
|
|
||||||
clientIdentityPayloadBuffer.flip();
|
|
||||||
return clientIdentityPayloadBuffer.array();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,10 +10,10 @@ import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||||
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
|
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
|
||||||
|
|
||||||
public enum CloseReason {
|
public enum CloseReason {
|
||||||
|
OK,
|
||||||
SERVER_CLOSED,
|
SERVER_CLOSED,
|
||||||
NOISE_ERROR,
|
NOISE_ERROR,
|
||||||
NOISE_HANDSHAKE_ERROR,
|
NOISE_HANDSHAKE_ERROR,
|
||||||
AUTHENTICATION_ERROR,
|
|
||||||
INTERNAL_SERVER_ERROR,
|
INTERNAL_SERVER_ERROR,
|
||||||
UNKNOWN
|
UNKNOWN
|
||||||
}
|
}
|
||||||
|
@ -27,27 +27,27 @@ public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeIniti
|
||||||
CloseWebSocketFrame closeWebSocketFrame,
|
CloseWebSocketFrame closeWebSocketFrame,
|
||||||
CloseInitiator closeInitiator) {
|
CloseInitiator closeInitiator) {
|
||||||
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
|
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
|
||||||
case 4003 -> CloseReason.NOISE_ERROR;
|
|
||||||
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||||
case 4002 -> CloseReason.AUTHENTICATION_ERROR;
|
case 4002 -> CloseReason.NOISE_ERROR;
|
||||||
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
|
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||||
case 1012 -> CloseReason.SERVER_CLOSED;
|
case 1012 -> CloseReason.SERVER_CLOSED;
|
||||||
|
case 1000 -> CloseReason.OK;
|
||||||
default -> CloseReason.UNKNOWN;
|
default -> CloseReason.UNKNOWN;
|
||||||
};
|
};
|
||||||
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
|
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static CloseFrameEvent fromNoiseDirectErrorFrame(
|
public static CloseFrameEvent fromNoiseDirectCloseFrame(
|
||||||
NoiseDirectProtos.Error noiseDirectError,
|
NoiseDirectProtos.CloseReason noiseDirectCloseReason,
|
||||||
CloseInitiator closeInitiator) {
|
CloseInitiator closeInitiator) {
|
||||||
final CloseReason code = switch (noiseDirectError.getType()) {
|
final CloseReason code = switch (noiseDirectCloseReason.getCode()) {
|
||||||
|
case OK -> CloseReason.OK;
|
||||||
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||||
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
|
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
|
||||||
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
|
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
|
||||||
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
|
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||||
case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR;
|
|
||||||
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
|
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
|
||||||
};
|
};
|
||||||
return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage());
|
return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,10 @@ import io.netty.channel.socket.SocketChannel;
|
||||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import java.net.SocketAddress;
|
import java.net.SocketAddress;
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.annotation.Nullable;
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,12 +29,10 @@ import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||||
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
private final List<ChannelHandler> remoteHandlerStack;
|
private final List<ChannelHandler> remoteHandlerStack;
|
||||||
@Nullable
|
private final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||||
private final AuthenticatedDevice authenticatedDevice;
|
|
||||||
|
|
||||||
private final SocketAddress remoteServerAddress;
|
private final SocketAddress remoteServerAddress;
|
||||||
// If provided, will be sent with the payload in the noise handshake
|
// If provided, will be sent with the payload in the noise handshake
|
||||||
private final byte[] fastOpenRequest;
|
|
||||||
|
|
||||||
private final List<Object> pendingReads = new ArrayList<>();
|
private final List<Object> pendingReads = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -44,13 +40,11 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
EstablishRemoteConnectionHandler(
|
EstablishRemoteConnectionHandler(
|
||||||
final List<ChannelHandler> remoteHandlerStack,
|
final List<ChannelHandler> remoteHandlerStack,
|
||||||
@Nullable final AuthenticatedDevice authenticatedDevice,
|
|
||||||
final SocketAddress remoteServerAddress,
|
final SocketAddress remoteServerAddress,
|
||||||
@Nullable byte[] fastOpenRequest) {
|
final NoiseTunnelProtos.HandshakeInit handshakeInit) {
|
||||||
this.remoteHandlerStack = remoteHandlerStack;
|
this.remoteHandlerStack = remoteHandlerStack;
|
||||||
this.authenticatedDevice = authenticatedDevice;
|
this.handshakeInit = handshakeInit;
|
||||||
this.remoteServerAddress = remoteServerAddress;
|
this.remoteServerAddress = remoteServerAddress;
|
||||||
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -72,16 +66,19 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
throws Exception {
|
throws Exception {
|
||||||
switch (event) {
|
switch (event) {
|
||||||
case ReadyForNoiseHandshakeEvent ignored ->
|
case ReadyForNoiseHandshakeEvent ignored ->
|
||||||
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload()))
|
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray()))
|
||||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
case NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) -> {
|
case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> {
|
||||||
remoteContext.pipeline()
|
remoteContext.pipeline()
|
||||||
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||||
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||||
|
|
||||||
// If there was a payload response on the handshake, write it back to our gRPC client
|
// If there was a payload response on the handshake, write it back to our gRPC client
|
||||||
fastResponse.ifPresent(plaintext ->
|
if (!handshakeResponse.getFastOpenResponse().isEmpty()) {
|
||||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
|
localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse
|
||||||
|
.getFastOpenResponse()
|
||||||
|
.asReadOnlyByteBuffer()));
|
||||||
|
}
|
||||||
|
|
||||||
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
||||||
pendingReads.forEach(localContext::fireChannelRead);
|
pendingReads.forEach(localContext::fireChannelRead);
|
||||||
|
@ -120,17 +117,4 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
pendingReads.clear();
|
pendingReads.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
private byte[] initialPayload() {
|
|
||||||
if (authenticatedDevice == null) {
|
|
||||||
return fastOpenRequest;
|
|
||||||
}
|
|
||||||
|
|
||||||
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
|
|
||||||
bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits());
|
|
||||||
bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits());
|
|
||||||
bb.put(authenticatedDevice.deviceId());
|
|
||||||
bb.put(fastOpenRequest);
|
|
||||||
bb.flip();
|
|
||||||
return bb.array();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||||
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -12,4 +14,4 @@ import java.util.Optional;
|
||||||
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
|
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
|
||||||
* the server included a payload in the handshake response.
|
* the server included a payload in the handshake response.
|
||||||
*/
|
*/
|
||||||
public record NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) {}
|
public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
|
@ -7,9 +8,10 @@ import io.netty.channel.ChannelDuplexHandler;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
|
||||||
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||||
|
|
||||||
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
@ -38,9 +40,13 @@ public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||||
if (message instanceof ByteBuf frame) {
|
if (message instanceof ByteBuf frame) {
|
||||||
try {
|
try {
|
||||||
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
|
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
|
||||||
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
|
final NoiseTunnelProtos.HandshakeResponse handshakeResponse =
|
||||||
|
NoiseTunnelProtos.HandshakeResponse.parseFrom(payload);
|
||||||
|
|
||||||
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
||||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
|
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse));
|
||||||
|
} catch (InvalidProtocolBufferException e) {
|
||||||
|
throw new NoiseHandshakeException("Failed to parse handshake response");
|
||||||
} finally {
|
} finally {
|
||||||
frame.release();
|
frame.release();
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,9 +8,11 @@ import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelDuplexHandler;
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
||||||
|
@ -72,8 +74,10 @@ public class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||||
ReferenceCountUtil.release(plaintext);
|
ReferenceCountUtil.release(plaintext);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Clients only write ByteBufs or close the connection on errors, so any other message is unexpected
|
if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) {
|
||||||
log.warn("Unexpected object in pipeline: {}", message);
|
// Clients only write ByteBufs or a close frame on errors, so any other message is unexpected
|
||||||
|
log.warn("Unexpected object in pipeline: {}", message);
|
||||||
|
}
|
||||||
context.write(message, promise);
|
context.write(message, promise);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,21 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
import com.southernstorm.noise.protocol.Noise;
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
import io.netty.bootstrap.ServerBootstrap;
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.channel.*;
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.Channel;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
import io.netty.channel.local.LocalServerChannel;
|
import io.netty.channel.local.LocalServerChannel;
|
||||||
|
@ -17,6 +27,13 @@ import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
||||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
import io.netty.handler.codec.http.HttpClientCodec;
|
import io.netty.handler.codec.http.HttpClientCodec;
|
||||||
import io.netty.handler.codec.http.HttpHeaders;
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||||
|
import io.netty.handler.ssl.SslContextBuilder;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import java.net.SocketAddress;
|
import java.net.SocketAddress;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.security.cert.X509Certificate;
|
import java.security.cert.X509Certificate;
|
||||||
|
@ -26,26 +43,21 @@ import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
import javax.net.ssl.SSLException;
|
||||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
|
||||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
|
||||||
import io.netty.handler.ssl.SslContextBuilder;
|
|
||||||
import io.netty.util.ReferenceCountUtil;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
|
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
import javax.net.ssl.SSLException;
|
|
||||||
|
|
||||||
public class NoiseTunnelClient implements AutoCloseable {
|
public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
|
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
|
||||||
|
private final CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture;
|
||||||
|
private final CompletableFuture<Void> userCloseFuture;
|
||||||
private final ServerBootstrap serverBootstrap;
|
private final ServerBootstrap serverBootstrap;
|
||||||
private Channel serverChannel;
|
private Channel serverChannel;
|
||||||
|
|
||||||
|
@ -66,11 +78,10 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
FramingType framingType = FramingType.WEBSOCKET;
|
FramingType framingType = FramingType.WEBSOCKET;
|
||||||
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
||||||
HttpHeaders headers = new DefaultHttpHeaders();
|
HttpHeaders headers = new DefaultHttpHeaders();
|
||||||
|
NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder();
|
||||||
|
|
||||||
boolean authenticated = false;
|
boolean authenticated = false;
|
||||||
ECKeyPair ecKeyPair = null;
|
ECKeyPair ecKeyPair = null;
|
||||||
UUID accountIdentifier = null;
|
|
||||||
byte deviceId = 0x00;
|
|
||||||
boolean useTls;
|
boolean useTls;
|
||||||
X509Certificate trustedServerCertificate = null;
|
X509Certificate trustedServerCertificate = null;
|
||||||
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
||||||
|
@ -86,8 +97,8 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
||||||
this.authenticated = true;
|
this.authenticated = true;
|
||||||
this.accountIdentifier = accountIdentifier;
|
handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier));
|
||||||
this.deviceId = deviceId;
|
handshakeInit.setDeviceId(deviceId);
|
||||||
this.ecKeyPair = ecKeyPair;
|
this.ecKeyPair = ecKeyPair;
|
||||||
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
||||||
return this;
|
return this;
|
||||||
|
@ -109,6 +120,16 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setUserAgent(final String userAgent) {
|
||||||
|
handshakeInit.setUserAgent(userAgent);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setAcceptLanguage(final String acceptLanguage) {
|
||||||
|
handshakeInit.setAcceptLanguage(acceptLanguage);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Builder setHeaders(final HttpHeaders headers) {
|
public Builder setHeaders(final HttpHeaders headers) {
|
||||||
this.headers = headers;
|
this.headers = headers;
|
||||||
return this;
|
return this;
|
||||||
|
@ -155,17 +176,41 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
handlers.add(new NoiseClientHandshakeHandler(helper));
|
handlers.add(new NoiseClientHandshakeHandler(helper));
|
||||||
|
|
||||||
|
// When the noise handshake completes we'll save the response from the server so client users can inspect it
|
||||||
|
final UserEventFuture<NoiseClientHandshakeCompleteEvent> handshakeEventHandler =
|
||||||
|
new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class);
|
||||||
|
handlers.add(handshakeEventHandler);
|
||||||
|
|
||||||
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
|
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
|
||||||
// information about why the connection was closed.
|
// information about why the connection was closed.
|
||||||
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
|
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
|
||||||
handlers.add(closeEventHandler);
|
handlers.add(closeEventHandler);
|
||||||
|
|
||||||
|
// When the user closes the client, write a normal closure close frame
|
||||||
|
final CompletableFuture<Void> userCloseFuture = new CompletableFuture<>();
|
||||||
|
handlers.add(new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void handlerAdded(final ChannelHandlerContext ctx) {
|
||||||
|
userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) {
|
||||||
|
case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE);
|
||||||
|
case NOISE_DIRECT -> new NoiseDirectFrame(
|
||||||
|
NoiseDirectFrame.FrameType.CLOSE,
|
||||||
|
Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason
|
||||||
|
.newBuilder()
|
||||||
|
.setCode(NoiseDirectProtos.CloseReason.Code.OK)
|
||||||
|
.build()
|
||||||
|
.toByteArray()));
|
||||||
|
})
|
||||||
|
.addListener(ChannelFutureListener.CLOSE),
|
||||||
|
ctx.executor());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
final NoiseTunnelClient client =
|
final NoiseTunnelClient client =
|
||||||
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||||
handlers,
|
handlers,
|
||||||
authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null,
|
|
||||||
remoteServerAddress,
|
remoteServerAddress,
|
||||||
fastOpenRequest));
|
handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build()));
|
||||||
client.start();
|
client.start();
|
||||||
return client;
|
return client;
|
||||||
}
|
}
|
||||||
|
@ -173,9 +218,13 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
|
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
|
||||||
CompletableFuture<CloseFrameEvent> closeEventFuture,
|
CompletableFuture<CloseFrameEvent> closeEventFuture,
|
||||||
|
CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture,
|
||||||
|
CompletableFuture<Void> userCloseFuture,
|
||||||
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
||||||
|
|
||||||
|
this.userCloseFuture = userCloseFuture;
|
||||||
this.closeEventFuture = closeEventFuture;
|
this.closeEventFuture = closeEventFuture;
|
||||||
|
this.handshakeEventFuture = handshakeEventFuture;
|
||||||
this.serverBootstrap = new ServerBootstrap()
|
this.serverBootstrap = new ServerBootstrap()
|
||||||
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||||
.channel(LocalServerChannel.class)
|
.channel(LocalServerChannel.class)
|
||||||
|
@ -194,10 +243,10 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
.addLast(new ChannelInboundHandlerAdapter() {
|
.addLast(new ChannelInboundHandlerAdapter() {
|
||||||
@Override
|
@Override
|
||||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||||
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
|
if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) {
|
||||||
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
|
byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest);
|
||||||
requestBufferedEvent.fastOpenRequest().release();
|
fastOpenRequest.release();
|
||||||
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
|
ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes));
|
||||||
}
|
}
|
||||||
super.userEventTriggered(ctx, evt);
|
super.userEventTriggered(ctx, evt);
|
||||||
}
|
}
|
||||||
|
@ -216,7 +265,7 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) {
|
||||||
if (cls.isInstance(evt)) {
|
if (cls.isInstance(evt)) {
|
||||||
future.complete((T) evt);
|
future.complete((T) evt);
|
||||||
}
|
}
|
||||||
|
@ -236,6 +285,7 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws InterruptedException {
|
public void close() throws InterruptedException {
|
||||||
|
userCloseFuture.complete(null);
|
||||||
serverChannel.close().await();
|
serverChannel.close().await();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,6 +296,14 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
return closeEventFuture;
|
return closeEventFuture;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return A future that completes when the noise handshake finishes
|
||||||
|
*/
|
||||||
|
public CompletableFuture<NoiseClientHandshakeCompleteEvent> getHandshakeEventFuture() {
|
||||||
|
return handshakeEventFuture;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
|
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
|
||||||
return List.of(
|
return List.of(
|
||||||
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
|
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
|
||||||
|
@ -259,12 +317,12 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||||
try {
|
try {
|
||||||
final NoiseDirectProtos.Error errorPayload =
|
final NoiseDirectProtos.CloseReason closeReason =
|
||||||
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||||
ctx.fireUserEventTriggered(
|
ctx.fireUserEventTriggered(
|
||||||
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER));
|
CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER));
|
||||||
} finally {
|
} finally {
|
||||||
ReferenceCountUtil.release(msg);
|
ReferenceCountUtil.release(msg);
|
||||||
}
|
}
|
||||||
|
@ -275,11 +333,11 @@ public class NoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
|
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||||
final NoiseDirectProtos.Error errorPayload =
|
final NoiseDirectProtos.CloseReason errorPayload =
|
||||||
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||||
ctx.fireUserEventTriggered(
|
ctx.fireUserEventTriggered(
|
||||||
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||||
}
|
}
|
||||||
ctx.write(msg, promise);
|
ctx.write(msg, promise);
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,11 +126,13 @@ class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelSe
|
||||||
|
|
||||||
final HttpHeaders headers = new DefaultHttpHeaders()
|
final HttpHeaders headers = new DefaultHttpHeaders()
|
||||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||||
.add("X-Forwarded-For", remoteAddress)
|
.add("X-Forwarded-For", remoteAddress);
|
||||||
.add("Accept-Language", acceptLanguage)
|
|
||||||
.add("User-Agent", userAgent);
|
|
||||||
|
|
||||||
try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) {
|
try (final NoiseTunnelClient client = anonymous()
|
||||||
|
.setHeaders(headers)
|
||||||
|
.setUserAgent(userAgent)
|
||||||
|
.setAcceptLanguage(acceptLanguage)
|
||||||
|
.build()) {
|
||||||
|
|
||||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,10 @@ package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.nio.NioEventLoopGroup;
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
import java.util.concurrent.Executor;
|
import java.util.concurrent.Executor;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
|
|
@ -6,9 +6,9 @@ import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
|
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
|
||||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
|
|
||||||
import com.google.common.net.InetAddresses;
|
import com.google.common.net.InetAddresses;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.ChannelHandler;
|
import io.netty.channel.ChannelHandler;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
@ -17,7 +17,6 @@ import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
import io.netty.handler.codec.http.HttpHeaders;
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
import io.netty.util.Attribute;
|
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.SocketAddress;
|
import java.net.SocketAddress;
|
||||||
|
@ -32,13 +31,10 @@ import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
|
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
|
||||||
|
|
||||||
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
@ -84,9 +80,7 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||||
userEventRecordingHandler = new UserEventRecordingHandler();
|
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||||
|
|
||||||
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
||||||
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
|
new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET),
|
||||||
Curve.generateKeyPair(),
|
|
||||||
RECOGNIZED_PROXY_SECRET),
|
|
||||||
userEventRecordingHandler);
|
userEventRecordingHandler);
|
||||||
|
|
||||||
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
|
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
|
||||||
|
@ -94,22 +88,25 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource
|
@MethodSource
|
||||||
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends ChannelHandler> expectedHandlerClass) {
|
void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) {
|
||||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||||
|
|
||||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||||
|
|
||||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
|
||||||
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
|
|
||||||
|
|
||||||
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||||
|
|
||||||
|
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||||
|
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||||
|
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||||
|
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||||
|
assertNotNull(init);
|
||||||
|
assertEquals(init.getHandshakePattern(), pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||||
return List.of(
|
return List.of(
|
||||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseAuthenticatedHandler.class),
|
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK),
|
||||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseAnonymousHandler.class));
|
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -141,13 +138,19 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||||
embeddedChannel.setRemoteAddress(remoteAddress);
|
embeddedChannel.setRemoteAddress(remoteAddress);
|
||||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||||
|
|
||||||
|
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||||
|
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||||
assertEquals(expectedRemoteAddress,
|
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||||
Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
|
assertEquals(
|
||||||
.map(Attribute::get)
|
expectedRemoteAddress,
|
||||||
.map(RequestAttributes::remoteAddress)
|
Optional.ofNullable(init)
|
||||||
|
.map(NoiseHandshakeInit::getRemoteAddress)
|
||||||
.orElse(null));
|
.orElse(null));
|
||||||
|
if (expectedRemoteAddress == null) {
|
||||||
|
assertThrows(IllegalStateException.class, embeddedChannel::checkException);
|
||||||
|
} else {
|
||||||
|
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Arguments> getRemoteAddress() {
|
private static List<Arguments> getRemoteAddress() {
|
||||||
|
|
Loading…
Reference in New Issue