diff --git a/pom.xml b/pom.xml
index 2ba7feaa9..1c5011d46 100644
--- a/pom.xml
+++ b/pom.xml
@@ -286,6 +286,11 @@
libsignal-server
0.39.0
+
+ org.signal.forks
+ noise-java
+ 0.1.0
+
org.apache.logging.log4j
log4j-bom
diff --git a/service/config/sample.yml b/service/config/sample.yml
index e7a32878f..9f6ec4d7d 100644
--- a/service/config/sample.yml
+++ b/service/config/sample.yml
@@ -40,8 +40,6 @@ metrics:
- ^lettuce\..+$
reportOnStop: true
-grpcPort: 8080
-
tlsKeyStore:
password: secret://tlsKeyStore.password
diff --git a/service/pom.xml b/service/pom.xml
index 41fe438f5..bd9f692cd 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -52,6 +52,11 @@
libsignal-server
+
+ org.signal.forks
+ noise-java
+
+
io.dropwizard
dropwizard-core
@@ -242,8 +247,7 @@
io.grpc
- grpc-netty-shaded
- runtime
+ grpc-netty
io.grpc
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
index 3c60e922e..450ac7581 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
@@ -298,11 +298,6 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private TusConfiguration tus;
- @Valid
- @NotNull
- @JsonProperty
- private int grpcPort;
-
@Valid
@NotNull
@JsonProperty
@@ -539,10 +534,6 @@ public class WhisperServerConfiguration extends Configuration {
return tus;
}
- public int getGrpcPort() {
- return grpcPort;
- }
-
public ClientReleaseConfiguration getClientReleaseConfiguration() {
return clientRelease;
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index eb8b680d5..edacc7973 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -21,13 +21,13 @@ import io.dropwizard.core.setup.Bootstrap;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.jetty.HttpsConnectorFactory;
import io.grpc.ServerBuilder;
-import io.grpc.ServerInterceptors;
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
import io.lettuce.core.metrics.MicrometerOptions;
import io.lettuce.core.resource.ClientResources;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.grpc.MetricCollectingServerInterceptor;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
+import io.netty.channel.local.LocalAddress;
import java.net.http.HttpClient;
import java.time.Clock;
import java.time.Duration;
@@ -134,13 +134,14 @@ import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsGrpcService;
-import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
+import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
+import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
@@ -753,20 +754,67 @@ public class WhisperServerService extends Application grpcServer = ServerBuilder.forPort(config.getGrpcPort())
- .addService(ServerInterceptors.intercept(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager), basicCredentialAuthenticationInterceptor))
- .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
- .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
- .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
- .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor))
- .addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
- .addService(new PaymentsGrpcService(currencyManager))
- .addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
- config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
- .addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
+ final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
+
+ final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
+ final MetricCollectingServerInterceptor metricCollectingServerInterceptor =
+ new MetricCollectingServerInterceptor(Metrics.globalRegistry);
+
+ final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
+ final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor();
+ final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor();
+
+ final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
+ final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
+
+ final ManagedLocalGrpcServer anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, localEventLoopGroup) {
+ @Override
+ protected void configureServer(final ServerBuilder> serverBuilder) {
+ // Note: interceptors run in the reverse order they are added; the remote deprecation filter
+ // depends on the user-agent context so it has to come first here!
+ // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
+ serverBuilder
+ // TODO: specialize metrics with user-agent platform
+ .intercept(metricCollectingServerInterceptor)
+ .intercept(errorMappingInterceptor)
+ .intercept(acceptLanguageInterceptor)
+ .intercept(remoteDeprecationFilter)
+ .intercept(userAgentInterceptor)
+ .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
+ .addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
+ .addService(new PaymentsGrpcService(currencyManager))
+ .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
+ .addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
+ }
+ };
+
+ final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) {
+ @Override
+ protected void configureServer(final ServerBuilder> serverBuilder) {
+ // Note: interceptors run in the reverse order they are added; the remote deprecation filter
+ // depends on the user-agent context so it has to come first here!
+ // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
+ serverBuilder
+ // TODO: specialize metrics with user-agent platform
+ .intercept(metricCollectingServerInterceptor)
+ .intercept(errorMappingInterceptor)
+ .intercept(acceptLanguageInterceptor)
+ .intercept(remoteDeprecationFilter)
+ .intercept(userAgentInterceptor)
+ .intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager)))
+ .addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
+ .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
+ .addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))
+ .addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
+ config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()));
+ }
+ };
+
+ environment.lifecycle().manage(localEventLoopGroup);
+ environment.lifecycle().manage(anonymousGrpcServer);
+ environment.lifecycle().manage(authenticatedGrpcServer);
final List filters = new ArrayList<>();
- final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
filters.add(remoteDeprecationFilter);
filters.add(new RemoteAddressFilter(useRemoteAddress));
@@ -776,19 +824,6 @@ public class WhisperServerService extends Application accountAuthFilter =
new BasicCredentialAuthFilter.Builder()
.setAuthenticator(accountAuthenticator)
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java
deleted file mode 100644
index 1b88c290d..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright 2023 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.grpc;
-
-import java.io.IOException;
-import java.util.concurrent.TimeUnit;
-
-import io.dropwizard.lifecycle.Managed;
-import io.grpc.Server;
-
-public class GrpcServerManagedWrapper implements Managed {
-
- private final Server server;
-
- public GrpcServerManagedWrapper(final Server server) {
- this.server = server;
- }
-
- @Override
- public void start() throws IOException {
- server.start();
- }
-
- @Override
- public void stop() {
- try {
- server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
- } catch (InterruptedException e) {
- server.shutdownNow();
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java
new file mode 100644
index 000000000..41594644d
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java
@@ -0,0 +1,124 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.util.internal.EmptyArrays;
+import java.security.NoSuchAlgorithmException;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+
+/**
+ * An abstract base class for XX- and NX-patterned Noise responder handshake handlers.
+ *
+ * @see The Noise Protocol Framework
+ */
+abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
+
+ private final ECKeyPair ecKeyPair;
+ private final byte[] publicKeySignature;
+
+ private final HandshakeState handshakeState;
+
+ private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32;
+
+ /**
+ * Constructs a new Noise handler with the given static server keys and static public key signature. The static public
+ * key must be signed by a trusted root private key whose public key is known to and trusted by authenticating
+ * clients.
+ *
+ * @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler
+ * @param ecKeyPair the static key pair for this server
+ * @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key
+ */
+ AbstractNoiseHandshakeHandler(final String noiseProtocolName,
+ final ECKeyPair ecKeyPair,
+ final byte[] publicKeySignature) {
+
+ this.ecKeyPair = ecKeyPair;
+ this.publicKeySignature = publicKeySignature;
+
+ try {
+ this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER);
+ } catch (final NoSuchAlgorithmException e) {
+ throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e);
+ }
+ }
+
+ protected HandshakeState getHandshakeState() {
+ return handshakeState;
+ }
+
+ /**
+ * Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's
+ * static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server.
+ * Clients must not include an additional payload with their ephemeral key message. The server's reply contains its
+ * static keys along with an Ed25519 signature of its public static key by a trusted root key.
+ *
+ * @param context the channel handler context for this message
+ * @param frame the websocket frame containing the ephemeral key message
+ *
+ * @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a
+ * general Noise encryption error occurred
+ */
+ protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
+ throws NoiseHandshakeException {
+
+ if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) {
+ throw new NoiseHandshakeException("Unexpected ephemeral key message length");
+ }
+
+ // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
+ // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
+ // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
+ // key), so we just set the private key here.
+ handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
+ handshakeState.start();
+
+ // The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is
+ // coming off the wire and so will be in a direct buffer that doesn't have a backing array.
+ final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content());
+ frame.content().readBytes(ephemeralKeyMessage);
+
+ try {
+ handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
+ } catch (final ShortBufferException e) {
+ // This should never happen since we're checking the length of the frame up front
+ throw new NoiseHandshakeException("Unexpected client payload");
+ } catch (final BadPaddingException e) {
+ // It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag"
+ // subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an
+ // ephemeral-key-only message
+ throw new NoiseHandshakeException("Invalid keys");
+ }
+
+ // Send our key material and public key signature back to the client; this buffer will include:
+ //
+ // - A 32-byte plaintext ephemeral key
+ // - A 32-byte encrypted static key
+ // - A 16-byte AEAD tag for the static key
+ // - The public key signature payload
+ // - A 16-byte AEAD tag for the payload
+ final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16];
+
+ try {
+ handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length);
+
+ context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial)))
+ .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
+ } catch (final ShortBufferException e) {
+ // This should never happen for messages of known length that we control
+ throw new AssertionError("Key material buffer was too short for message", e);
+ }
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ handshakeState.destroy();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java
new file mode 100644
index 000000000..a0ba4ee0b
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java
@@ -0,0 +1,24 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
+
+enum ApplicationWebSocketCloseReason {
+ NOISE_HANDSHAKE_ERROR(4001),
+ CLIENT_AUTHENTICATION_ERROR(4002),
+ NOISE_ENCRYPTION_ERROR(4003),
+ REAUTHENTICATION_REQUIRED(4004);
+
+ private final int statusCode;
+
+ ApplicationWebSocketCloseReason(final int statusCode) {
+ this.statusCode = statusCode;
+ }
+
+ public int getStatusCode() {
+ return statusCode;
+ }
+
+ WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
+ return new WebSocketCloseStatus(statusCode, reason);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java
new file mode 100644
index 000000000..f3015b7e8
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java
@@ -0,0 +1,7 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+/**
+ * Indicates that an attempt to authenticate a remote client failed for some reason.
+ */
+class ClientAuthenticationException extends Exception {
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java
new file mode 100644
index 000000000..1828ee689
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java
@@ -0,0 +1,59 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import javax.crypto.BadPaddingException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
+ * WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
+ * identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
+ * connection.
+ */
+class ErrorHandler extends ChannelInboundHandlerAdapter {
+
+ private boolean websocketHandshakeComplete = false;
+
+ private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
+ if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
+ setWebsocketHandshakeComplete();
+ }
+
+ context.fireUserEventTriggered(event);
+ }
+
+ protected void setWebsocketHandshakeComplete() {
+ this.websocketHandshakeComplete = true;
+ }
+
+ @Override
+ public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
+ if (websocketHandshakeComplete) {
+ final WebSocketCloseStatus webSocketCloseStatus = switch (cause) {
+ case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
+ case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
+ case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
+ default -> {
+ log.warn("An unexpected exception reached the end of the pipeline", cause);
+ yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
+ }
+ };
+
+ context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
+ .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
+ } else {
+ // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
+ // way; just close the connection instead.
+ context.close();
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java
new file mode 100644
index 000000000..37be75df0
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java
@@ -0,0 +1,96 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
+import io.netty.util.ReferenceCountUtil;
+import java.util.ArrayList;
+import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
+ * any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
+ * server.
+ */
+class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
+
+ private final LocalAddress authenticatedGrpcServerAddress;
+ private final LocalAddress anonymousGrpcServerAddress;
+ private final List pendingReads = new ArrayList<>();
+
+ private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
+
+ public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress,
+ final LocalAddress anonymousGrpcServerAddress) {
+
+ this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
+ this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) {
+ pendingReads.add(message);
+ }
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception {
+ if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
+ // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
+ // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
+ // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
+ // authenticated service.
+ final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent()
+ ? authenticatedGrpcServerAddress
+ : anonymousGrpcServerAddress;
+
+ new Bootstrap()
+ .remoteAddress(grpcServerAddress)
+ // TODO Set local address
+ .channel(LocalChannel.class)
+ .group(remoteChannelContext.channel().eventLoop())
+ .handler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(final LocalChannel localChannel) {
+ localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel()));
+ }
+ })
+ .connect()
+ .addListener((ChannelFutureListener) future -> {
+ if (future.isSuccess()) {
+ // Close the local connection if the remote channel closes and vice versa
+ remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
+ future.channel().closeFuture().addListener(closeFuture ->
+ remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
+
+ remoteChannelContext.pipeline()
+ .addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel()));
+
+ // Flush any buffered reads we accumulated while waiting to open the connection
+ pendingReads.forEach(remoteChannelContext::fireChannelRead);
+ pendingReads.clear();
+
+ remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
+ } else {
+ log.warn("Failed to establish local connection to gRPC server", future.cause());
+ remoteChannelContext.close();
+ }
+ });
+ }
+
+ remoteChannelContext.fireUserEventTriggered(event);
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ pendingReads.forEach(ReferenceCountUtil::release);
+ pendingReads.clear();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java
new file mode 100644
index 000000000..5888174c4
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java
@@ -0,0 +1,16 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.dropwizard.lifecycle.Managed;
+import io.netty.channel.DefaultEventLoopGroup;
+
+/**
+ * A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
+ * Dropwizard to manage the lifecycle of the event loop group.
+ */
+public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed {
+
+ @Override
+ public void stop() throws InterruptedException {
+ this.shutdownGracefully().await();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java
new file mode 100644
index 000000000..21cd3d846
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java
@@ -0,0 +1,49 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.dropwizard.lifecycle.Managed;
+import io.grpc.Server;
+import io.grpc.ServerBuilder;
+import io.grpc.netty.NettyServerBuilder;
+import io.netty.channel.DefaultEventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalServerChannel;
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress}
+ * and whose lifecycle is managed by Dropwizard via the {@link Managed} interface.
+ */
+public abstract class ManagedLocalGrpcServer implements Managed {
+
+ private final Server server;
+
+ public ManagedLocalGrpcServer(final LocalAddress localAddress,
+ final DefaultEventLoopGroup eventLoopGroup) {
+
+ final ServerBuilder> serverBuilder = NettyServerBuilder.forAddress(localAddress)
+ .channelType(LocalServerChannel.class)
+ .bossEventLoopGroup(eventLoopGroup)
+ .workerEventLoopGroup(eventLoopGroup);
+
+ configureServer(serverBuilder);
+
+ server = serverBuilder.build();
+ }
+
+ protected abstract void configureServer(final ServerBuilder> serverBuilder);
+
+ @Override
+ public void start() throws IOException {
+ server.start();
+ }
+
+ @Override
+ public void stop() {
+ try {
+ server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
+ } catch (final InterruptedException e) {
+ server.shutdownNow();
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java
new file mode 100644
index 000000000..06d3e97db
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java
@@ -0,0 +1,16 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.dropwizard.lifecycle.Managed;
+import io.netty.channel.nio.NioEventLoopGroup;
+
+/**
+ * A wrapper for a Netty {@link NioEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
+ * Dropwizard to manage the lifecycle of the event loop group.
+ */
+public class ManagedNioEventLoopGroup extends NioEventLoopGroup implements Managed {
+
+ @Override
+ public void stop() throws Exception {
+ this.shutdownGracefully().await();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java
new file mode 100644
index 000000000..5a2f1ae99
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java
@@ -0,0 +1,13 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
+import java.util.Optional;
+
+/**
+ * An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process.
+ *
+ * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
+ * type that performs authentication
+ */
+record NoiseHandshakeCompleteEvent(Optional authenticatedDevice) {
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java
new file mode 100644
index 000000000..1a11bcec6
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java
@@ -0,0 +1,12 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+/**
+ * Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
+ * a general encryption error).
+ */
+class NoiseHandshakeException extends Exception {
+
+ public NoiseHandshakeException(final String message) {
+ super(message);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java
new file mode 100644
index 000000000..8cddd32cf
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java
@@ -0,0 +1,40 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import java.util.Optional;
+import io.netty.util.ReferenceCountUtil;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+
+/**
+ * A Noise NX handler handles the responder side of a Noise NX handshake.
+ */
+class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler {
+
+ static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b";
+
+ NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) {
+ super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof BinaryWebSocketFrame frame) {
+ try {
+ handleEphemeralKeyMessage(context, frame);
+ } finally {
+ frame.release();
+ }
+
+ // All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider
+ // the handshake complete
+ context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
+ context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseStreamHandler(getHandshakeState().split()));
+ } else {
+ // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
+ // error
+ ReferenceCountUtil.release(message);
+ throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java
new file mode 100644
index 000000000..b2776b2cc
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java
@@ -0,0 +1,95 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.CipherState;
+import com.southernstorm.noise.protocol.CipherStatePair;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelDuplexHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketFrame;
+import io.netty.util.ReferenceCountUtil;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A Noise stream handler manages a bidirectional Noise session after a handshake has completed.
+ */
+class NoiseStreamHandler extends ChannelDuplexHandler {
+
+ private final CipherStatePair cipherStatePair;
+
+ private static final Logger log = LoggerFactory.getLogger(NoiseStreamHandler.class);
+
+ NoiseStreamHandler(CipherStatePair cipherStatePair) {
+ this.cipherStatePair = cipherStatePair;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message)
+ throws ShortBufferException, BadPaddingException {
+
+ if (message instanceof BinaryWebSocketFrame frame) {
+ try {
+ final CipherState cipherState = cipherStatePair.getReceiver();
+
+ // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
+ // We'll need to copy it to a heap buffer.
+ final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content());
+
+ // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
+ final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
+
+ context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
+ } finally {
+ frame.release();
+ }
+ } else {
+ // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
+ // error
+ ReferenceCountUtil.release(message);
+ throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
+ }
+ }
+
+ @Override
+ public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
+ if (message instanceof ByteBuf plaintext) {
+ try {
+ // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
+ final CipherState cipherState = cipherStatePair.getSender();
+ final int plaintextLength = plaintext.readableBytes();
+
+ // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
+ // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
+ // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
+ final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
+ plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
+
+ // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
+ cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
+
+ context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
+ } finally {
+ plaintext.release();
+ }
+ } else {
+ if (!(message instanceof WebSocketFrame)) {
+ // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
+ // get issued in response to exceptions)
+ log.warn("Unexpected object in pipeline: {}", message);
+ }
+
+ context.write(message, promise);
+ }
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ cipherStatePair.destroy();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java
new file mode 100644
index 000000000..2ebc9ea85
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java
@@ -0,0 +1,178 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.util.ReferenceCountUtil;
+import java.security.MessageDigest;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+
+/**
+ * A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send
+ * identifying information (an account identifier and device ID) as an additional payload when sending its static key
+ * material. It compares the static public key against the stored public key for the identified device asynchronously,
+ * buffering traffic from the client until the authentication check completes.
+ */
+class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler {
+
+ private final ClientPublicKeysManager clientPublicKeysManager;
+
+ private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY;
+
+ private final List pendingInboundFrames = new ArrayList<>();
+
+ static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b";
+
+ // When the client sends its static key message, we expect:
+ //
+ // - A 32-byte encrypted static public key
+ // - A 16-byte AEAD tag for the static key
+ // - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID)
+ // - A 16-byte AEAD tag for the identity payload
+ private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81;
+
+ private enum AuthenticationState {
+ GET_EPHEMERAL_KEY,
+ GET_STATIC_KEY,
+ CHECK_PUBLIC_KEY,
+ ERROR
+ }
+
+ public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager,
+ final ECKeyPair ecKeyPair,
+ final byte[] publicKeySignature) {
+
+ super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
+
+ this.clientPublicKeysManager = clientPublicKeysManager;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof BinaryWebSocketFrame frame) {
+ try {
+ switch (authenticationState) {
+ case GET_EPHEMERAL_KEY -> {
+ try {
+ handleEphemeralKeyMessage(context, frame);
+ authenticationState = AuthenticationState.GET_STATIC_KEY;
+ } finally {
+ frame.release();
+ }
+ }
+ case GET_STATIC_KEY -> {
+ try {
+ handleStaticKey(context, frame);
+ authenticationState = AuthenticationState.CHECK_PUBLIC_KEY;
+ } finally {
+ frame.release();
+ }
+ }
+ case CHECK_PUBLIC_KEY -> {
+ // Buffer any inbound traffic until we've finished checking the client's public key
+ pendingInboundFrames.add(frame);
+ }
+ case ERROR -> {
+ // If authentication has failed for any reason, just discard inbound traffic until the channel closes
+ frame.release();
+ }
+ }
+ } catch (final ShortBufferException e) {
+ authenticationState = AuthenticationState.ERROR;
+ throw new NoiseHandshakeException("Unexpected payload length");
+ } catch (final BadPaddingException e) {
+ authenticationState = AuthenticationState.ERROR;
+ throw new ClientAuthenticationException();
+ }
+ } else {
+ // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
+ // error
+ ReferenceCountUtil.release(message);
+ throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
+ }
+ }
+
+ private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
+ throws NoiseHandshakeException, ShortBufferException, BadPaddingException {
+
+ if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) {
+ throw new NoiseHandshakeException("Unexpected client static key message length");
+ }
+
+ final HandshakeState handshakeState = getHandshakeState();
+
+ // The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct
+ // buffer into a heap buffer.
+ final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content());
+
+ // The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte)
+ final byte[] payload = new byte[17];
+
+ final UUID accountIdentifier;
+ final byte deviceId;
+
+ final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage,
+ 0, staticKeyAndClientIdentityMessage.length, payload, 0);
+
+ if (payloadBytesRead != 17) {
+ throw new NoiseHandshakeException("Unexpected identity payload length");
+ }
+
+ try {
+ accountIdentifier = UUIDUtil.fromBytes(payload, 0);
+ } catch (final IllegalArgumentException e) {
+ throw new NoiseHandshakeException("Could not parse account identifier");
+ }
+
+ deviceId = payload[16];
+
+ // Verify the identity of the caller by comparing the submitted static public key against the stored public key for
+ // the identified device
+ clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)
+ .whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> {
+ final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
+ handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
+
+ if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) {
+ context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(
+ Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))));
+
+ context.pipeline().addAfter(context.name(), null, new NoiseStreamHandler(handshakeState.split()));
+
+ // Flush any buffered reads
+ pendingInboundFrames.forEach(context::fireChannelRead);
+ pendingInboundFrames.clear();
+
+ context.pipeline().remove(NoiseXXHandshakeHandler.this);
+ } else {
+ // We found a key, but it doesn't match what the caller submitted
+ context.fireExceptionCaught(new ClientAuthenticationException());
+ authenticationState = AuthenticationState.ERROR;
+ }
+ },
+ () -> {
+ // We couldn't find a key for the identified account/device
+ context.fireExceptionCaught(new ClientAuthenticationException());
+ authenticationState = AuthenticationState.ERROR;
+ }),
+ context.executor());
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ super.handlerRemoved(context);
+
+ pendingInboundFrames.forEach(BinaryWebSocketFrame::release);
+ pendingInboundFrames.clear();
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java
new file mode 100644
index 000000000..2d5effc1a
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java
@@ -0,0 +1,24 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A proxy handler writes all data read from one channel to another peer channel.
+ */
+class ProxyHandler extends ChannelInboundHandlerAdapter {
+
+ private final Channel peerChannel;
+
+ public ProxyHandler(final Channel peerChannel) {
+ this.peerChannel = peerChannel;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) {
+ peerChannel.writeAndFlush(message)
+ .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java
new file mode 100644
index 000000000..41d7436e2
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java
@@ -0,0 +1,35 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
+import io.netty.handler.codec.http.websocketx.WebSocketFrame;
+import io.netty.util.ReferenceCountUtil;
+
+/**
+ * A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process.
+ */
+public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter {
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof final WebSocketFrame webSocketFrame) {
+ if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) {
+ try {
+ context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
+ } finally {
+ textWebSocketFrame.release();
+ }
+ } else {
+ // Allow all other types of WebSocket frames
+ context.fireChannelRead(webSocketFrame);
+ }
+ } else {
+ // Discard anything that's not a WebSocket frame
+ ReferenceCountUtil.release(message);
+ context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java
new file mode 100644
index 000000000..453ccb390
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java
@@ -0,0 +1,74 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.DefaultFullHttpResponse;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpHeaderValues;
+import io.netty.handler.codec.http.HttpMethod;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.util.ReferenceCountUtil;
+
+/**
+ * A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully
+ * rejects requests for anything other than a WebSocket connection to a known endpoint.
+ */
+class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter {
+
+ private final String authenticatedPath;
+ private final String anonymousPath;
+
+ WebSocketOpeningHandshakeHandler(final String authenticatedPath, final String anonymousPath) {
+ this.authenticatedPath = authenticatedPath;
+ this.anonymousPath = anonymousPath;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof FullHttpRequest request) {
+ boolean shouldReleaseRequest = true;
+
+ try {
+ if (request.decoderResult().isSuccess()) {
+ if (HttpMethod.GET.equals(request.method())) {
+ if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) {
+ if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
+ // Pass the request along to the websocket handshake handler and remove ourselves from the pipeline
+ shouldReleaseRequest = false;
+
+ context.fireChannelRead(request);
+ context.pipeline().remove(this);
+ } else {
+ closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED);
+ }
+ } else {
+ closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND);
+ }
+ } else {
+ closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED);
+ }
+ } else {
+ closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST);
+ }
+ } finally {
+ if (shouldReleaseRequest) {
+ request.release();
+ }
+ }
+ } else {
+ // Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error
+ ReferenceCountUtil.release(message);
+ throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
+ }
+ }
+
+ private static void closeConnectionWithStatus(final ChannelHandlerContext context,
+ final FullHttpRequest request,
+ final HttpResponseStatus status) {
+
+ context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status))
+ .addListener(ChannelFutureListener.CLOSE);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java
new file mode 100644
index 000000000..5b743a61c
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java
@@ -0,0 +1,52 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+
+/**
+ * A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate
+ * Noise handshake handler for the requested path.
+ */
+class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter {
+
+ private final ClientPublicKeysManager clientPublicKeysManager;
+
+ private final ECKeyPair ecKeyPair;
+ private final byte[] publicKeySignature;
+
+ WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager,
+ final ECKeyPair ecKeyPair,
+ final byte[] publicKeySignature) {
+
+ this.clientPublicKeysManager = clientPublicKeysManager;
+ this.ecKeyPair = ecKeyPair;
+ this.publicKeySignature = publicKeySignature;
+ }
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
+ if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
+ final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
+ case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
+ new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
+
+ case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
+ new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
+
+ default -> {
+ // The HttpHandler should have caught all of these cases already; we'll consider it an internal error if
+ // something slipped through.
+ throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
+ }
+ };
+
+ context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler);
+ }
+
+ context.fireUserEventTriggered(event);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java
new file mode 100644
index 000000000..f60e0f838
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java
@@ -0,0 +1,116 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.southernstorm.noise.protocol.Noise;
+import io.dropwizard.lifecycle.Managed;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.ServerSocketChannel;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.HttpServerCodec;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import io.netty.handler.ssl.ClientAuth;
+import io.netty.handler.ssl.OpenSsl;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.SslProtocols;
+import io.netty.handler.ssl.SslProvider;
+import java.net.InetSocketAddress;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.util.concurrent.Executor;
+import javax.net.ssl.SSLException;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+
+/**
+ * A WebSocket/Noise tunnel server accepts traffic from the public internet (in the form of Noise packets framed by
+ * binary WebSocket frames) and passes it through to a local gRPC server.
+ */
+public class WebsocketNoiseTunnelServer implements Managed {
+
+ private final ServerBootstrap bootstrap;
+ private ServerSocketChannel channel;
+
+ static final String AUTHENTICATED_SERVICE_PATH = "/authenticated";
+ static final String ANONYMOUS_SERVICE_PATH = "/anonymous";
+
+ private static final Logger log = LoggerFactory.getLogger(WebsocketNoiseTunnelServer.class);
+
+ public WebsocketNoiseTunnelServer(final int websocketPort,
+ final X509Certificate[] tlsCertificateChain,
+ final PrivateKey tlsPrivateKey,
+ final NioEventLoopGroup eventLoopGroup,
+ final Executor delegatedTaskExecutor,
+ final ClientPublicKeysManager clientPublicKeysManager,
+ final ECKeyPair ecKeyPair,
+ final byte[] publicKeySignature,
+ final LocalAddress authenticatedGrpcServerAddress,
+ final LocalAddress anonymousGrpcServerAddress) throws SSLException {
+
+ final SslProvider sslProvider;
+
+ if (OpenSsl.isAvailable()) {
+ log.info("Native OpenSSL provider is available; will use native provider");
+ sslProvider = SslProvider.OPENSSL;
+ } else {
+ log.info("No native SSL provider available; will use JDK provider");
+ sslProvider = SslProvider.JDK;
+ }
+
+ final SslContext sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
+ .clientAuth(ClientAuth.NONE)
+ .protocols(SslProtocols.TLS_v1_3)
+ .sslProvider(sslProvider)
+ .build();
+
+ this.bootstrap = new ServerBootstrap()
+ .group(eventLoopGroup)
+ .channel(NioServerSocketChannel.class)
+ .localAddress(websocketPort)
+ .childHandler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(SocketChannel socketChannel) {
+ socketChannel.pipeline()
+ .addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor))
+ .addLast(new HttpServerCodec())
+ .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
+ // The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade
+ // request and passed it down the pipeline
+ .addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH))
+ .addLast(new WebSocketServerProtocolHandler("/", true))
+ .addLast(new RejectUnsupportedMessagesHandler())
+ // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
+ // a WebSocket handshake has been completed
+ .addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature))
+ // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
+ // once the Noise handshake has completed
+ .addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
+ .addLast(new ErrorHandler());
+ }
+ });
+ }
+
+ @VisibleForTesting
+ InetSocketAddress getLocalAddress() {
+ return channel.localAddress();
+ }
+
+ @Override
+ public void start() throws InterruptedException {
+ channel = (ServerSocketChannel) bootstrap.bind().await().channel();
+ }
+
+ @Override
+ public void stop() throws InterruptedException {
+ if (channel != null) {
+ channel.close().await();
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
index 38f2cb34b..37cc6bbe3 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
@@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.util;
import com.google.protobuf.ByteString;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
-import java.util.Optional;
import java.util.UUID;
public final class UUIDUtil {
@@ -40,6 +39,10 @@ public final class UUIDUtil {
return fromByteBuffer(ByteBuffer.wrap(bytes));
}
+ public static UUID fromBytes(final byte[] bytes, final int offset) {
+ return fromByteBuffer(ByteBuffer.wrap(bytes, offset, 16));
+ }
+
public static UUID fromByteBuffer(final ByteBuffer byteBuffer) {
try {
final long mostSigBits = byteBuffer.getLong();
@@ -52,12 +55,4 @@ public final class UUIDUtil {
throw new IllegalArgumentException("unexpected byte array length; was less than 16");
}
}
-
- public static Optional fromStringSafe(final String uuidString) {
- try {
- return Optional.of(UUID.fromString(uuidString));
- } catch (final IllegalArgumentException e) {
- return Optional.empty();
- }
- }
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java
new file mode 100644
index 000000000..79fb2078f
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java
@@ -0,0 +1,21 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.util.ResourceLeakDetector;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+
+abstract class AbstractLeakDetectionTest {
+
+ private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
+
+ @BeforeAll
+ static void setLeakDetectionLevel() {
+ originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
+ ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
+ }
+
+ @AfterAll
+ static void restoreLeakDetectionLevel() {
+ ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java
new file mode 100644
index 000000000..63b114af5
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java
@@ -0,0 +1,94 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
+import java.security.NoSuchAlgorithmException;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+
+abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter {
+
+ private final ECPublicKey rootPublicKey;
+
+ private final HandshakeState handshakeState;
+
+ AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) {
+ this.rootPublicKey = rootPublicKey;
+
+ try {
+ handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR);
+ } catch (final NoSuchAlgorithmException e) {
+ throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e);
+ }
+ }
+
+ protected abstract String getNoiseProtocolName();
+
+ protected abstract void startHandshake();
+
+ protected HandshakeState getHandshakeState() {
+ return handshakeState;
+ }
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
+ if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
+ if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
+ startHandshake();
+
+ final byte[] ephemeralKeyMessage = new byte[32];
+ handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0);
+
+ context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage)))
+ .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
+ }
+ }
+
+ super.userEventTriggered(context, event);
+ }
+
+ protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
+ throws NoiseHandshakeException {
+
+ // The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap
+ // buffer so we can Noise at it.
+ final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes());
+ final byte[] serverPublicKeySignature = new byte[64];
+
+ try {
+ frame.content().readBytes(keyMaterialBuffer);
+
+ final int payloadBytesRead =
+ handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0);
+
+ if (payloadBytesRead != 64) {
+ throw new NoiseHandshakeException("Unexpected signature length");
+ }
+ } catch (final ShortBufferException e) {
+ throw new NoiseHandshakeException("Unexpected signature length");
+ } catch (final BadPaddingException e) {
+ throw new NoiseHandshakeException("Invalid keys");
+ } finally {
+ keyMaterialBuffer.release();
+ }
+
+ final byte[] serverPublicKey = new byte[32];
+ handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
+
+ if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) {
+ throw new NoiseHandshakeException("Invalid server public key signature");
+ }
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) throws Exception {
+ handshakeState.destroy();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java
new file mode 100644
index 000000000..885b08645
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java
@@ -0,0 +1,141 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.util.ReferenceCountUtil;
+import java.util.concurrent.ThreadLocalRandom;
+import javax.annotation.Nullable;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+
+abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest {
+
+ private ECPublicKey rootPublicKey;
+
+ private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
+
+ private EmbeddedChannel embeddedChannel;
+
+ private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
+
+ @Nullable
+ private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null;
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
+ if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
+ handshakeCompleteEvent = noiseHandshakeCompleteEvent;
+ } else {
+ context.fireUserEventTriggered(event);
+ }
+ }
+
+ @Nullable
+ public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() {
+ return handshakeCompleteEvent;
+ }
+ }
+
+ @BeforeEach
+ void setUp() {
+ final ECKeyPair rootKeyPair = Curve.generateKeyPair();
+ final ECKeyPair serverKeyPair = Curve.generateKeyPair();
+
+ rootPublicKey = rootKeyPair.getPublicKey();
+
+ final byte[] serverPublicKeySignature =
+ rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes());
+
+ noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
+
+ embeddedChannel =
+ new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler);
+ }
+
+ @AfterEach
+ void tearDown() {
+ embeddedChannel.close();
+ }
+
+ protected EmbeddedChannel getEmbeddedChannel() {
+ return embeddedChannel;
+ }
+
+ protected ECPublicKey getRootPublicKey() {
+ return rootPublicKey;
+ }
+
+ @Nullable
+ protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() {
+ return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
+ }
+
+ protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature);
+
+ @Test
+ void handleInvalidInitialMessage() throws InterruptedException {
+ final byte[] contentBytes = new byte[17];
+ ThreadLocalRandom.current().nextBytes(contentBytes);
+
+ final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
+
+ final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
+
+ assertFalse(writeFuture.isSuccess());
+ assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
+ assertEquals(0, content.refCnt());
+ assertNull(getNoiseHandshakeCompleteEvent());
+ }
+
+ @Test
+ void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
+ final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
+
+ for (int i = 0; i < frames.length; i++) {
+ final byte[] contentBytes = new byte[17];
+ ThreadLocalRandom.current().nextBytes(contentBytes);
+
+ frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
+
+ embeddedChannel.writeOneInbound(frames[i]).await();
+ }
+
+ for (final BinaryWebSocketFrame frame : frames) {
+ assertEquals(0, frame.refCnt());
+ }
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+ }
+
+ @Test
+ void handleNonWebSocketBinaryFrame() throws InterruptedException {
+ final byte[] contentBytes = new byte[17];
+ ThreadLocalRandom.current().nextBytes(contentBytes);
+
+ final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
+
+ final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
+
+ assertFalse(writeFuture.isSuccess());
+ assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
+ assertEquals(0, message.refCnt());
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java
new file mode 100644
index 000000000..112c1f20f
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java
@@ -0,0 +1,21 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.grpc.stub.StreamObserver;
+import org.signal.chat.rpc.AuthenticationTypeGrpc;
+import org.signal.chat.rpc.GetAuthenticatedRequest;
+import org.signal.chat.rpc.GetAuthenticatedResponse;
+
+public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
+
+ private final boolean authenticated;
+
+ public AuthenticationTypeService(final boolean authenticated) {
+ this.authenticated = authenticated;
+ }
+
+ @Override
+ public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver responseObserver) {
+ responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
+ responseObserver.onCompleted();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java
new file mode 100644
index 000000000..ae7bcc133
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java
@@ -0,0 +1,18 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
+
+class ClientErrorHandler extends ErrorHandler {
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
+ if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
+ if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
+ setWebsocketHandshakeComplete();
+ }
+ }
+
+ super.userEventTriggered(context, event);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java
new file mode 100644
index 000000000..ec22d6af7
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java
@@ -0,0 +1,141 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.Noise;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.http.DefaultHttpHeaders;
+import io.netty.handler.codec.http.HttpClientCodec;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
+import io.netty.handler.codec.http.websocketx.WebSocketVersion;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.util.ReferenceCountUtil;
+import java.net.SocketAddress;
+import java.net.URI;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import javax.net.ssl.SSLException;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+
+class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
+
+ private final X509Certificate trustedServerCertificate;
+ private final URI websocketUri;
+ private final boolean authenticated;
+ @Nullable private final ECKeyPair ecKeyPair;
+ private final ECPublicKey rootPublicKey;
+ @Nullable private final UUID accountIdentifier;
+ private final byte deviceId;
+ private final SocketAddress remoteServerAddress;
+ private final WebSocketCloseListener webSocketCloseListener;
+
+ private final List pendingReads = new ArrayList<>();
+
+ private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
+
+ EstablishRemoteConnectionHandler(
+ final X509Certificate trustedServerCertificate,
+ final URI websocketUri,
+ final boolean authenticated,
+ @Nullable final ECKeyPair ecKeyPair,
+ final ECPublicKey rootPublicKey,
+ @Nullable final UUID accountIdentifier,
+ final byte deviceId,
+ final SocketAddress remoteServerAddress,
+ final WebSocketCloseListener webSocketCloseListener) {
+
+ this.trustedServerCertificate = trustedServerCertificate;
+ this.websocketUri = websocketUri;
+ this.authenticated = authenticated;
+ this.ecKeyPair = ecKeyPair;
+ this.rootPublicKey = rootPublicKey;
+ this.accountIdentifier = accountIdentifier;
+ this.deviceId = deviceId;
+ this.remoteServerAddress = remoteServerAddress;
+ this.webSocketCloseListener = webSocketCloseListener;
+ }
+
+ @Override
+ public void handlerAdded(final ChannelHandlerContext localContext) {
+ new Bootstrap()
+ .channel(NioSocketChannel.class)
+ .group(localContext.channel().eventLoop())
+ .handler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(final SocketChannel channel) throws SSLException {
+ channel.pipeline()
+ .addLast(SslContextBuilder
+ .forClient()
+ .trustManager(trustedServerCertificate)
+ .build()
+ .newHandler(channel.alloc()))
+ .addLast(new HttpClientCodec())
+ .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
+ // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
+ // want to react to them on our own, we need to catch them before they hit that handler.
+ .addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener))
+ .addLast(new WebSocketClientProtocolHandler(websocketUri,
+ WebSocketVersion.V13,
+ null,
+ false,
+ new DefaultHttpHeaders(),
+ Noise.MAX_PACKET_LEN,
+ 10_000))
+ .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
+ .addLast(authenticated
+ ? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId)
+ : new NoiseNXClientHandshakeHandler(rootPublicKey))
+ .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
+ throws Exception {
+ if (event instanceof NoiseHandshakeCompleteEvent) {
+ remoteContext.pipeline()
+ .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
+
+ localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
+
+ pendingReads.forEach(localContext::fireChannelRead);
+ pendingReads.clear();
+
+ localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
+ }
+
+ super.userEventTriggered(remoteContext, event);
+ }
+ })
+ .addLast(new ClientErrorHandler());
+ }
+ })
+ .connect(remoteServerAddress)
+ .addListener((ChannelFutureListener) future -> {
+ if (future.isSuccess()) {
+ // Close the local connection if the remote channel closes and vice versa
+ future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
+ localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
+ } else {
+ localContext.close();
+ }
+ });
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) {
+ pendingReads.add(message);
+ }
+
+ @Override
+ public void handlerRemoved(final ChannelHandlerContext context) {
+ pendingReads.forEach(ReferenceCountUtil::release);
+ pendingReads.clear();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java
new file mode 100644
index 000000000..e946c937a
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java
@@ -0,0 +1,23 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+
+class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter {
+
+ private final WebSocketCloseListener webSocketCloseListener;
+
+ public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
+ this.webSocketCloseListener = webSocketCloseListener;
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
+ webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode());
+ }
+
+ super.channelRead(context, message);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java
new file mode 100644
index 000000000..fc8b68e85
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java
@@ -0,0 +1,47 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import java.util.Optional;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+
+class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler {
+
+ private boolean receivedServerStaticKeyMessage = false;
+
+ NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) {
+ super(rootPublicKey);
+ }
+
+ @Override
+ protected String getNoiseProtocolName() {
+ return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME;
+ }
+
+ @Override
+ protected void startHandshake() {
+ getHandshakeState().start();
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
+ if (message instanceof BinaryWebSocketFrame frame) {
+ try {
+ // Don't process additional messages if we're just waiting to close because the handshake failed
+ if (receivedServerStaticKeyMessage) {
+ return;
+ }
+
+ receivedServerStaticKeyMessage = true;
+ handleServerStaticKeyMessage(context, frame);
+
+ context.pipeline().replace(this, null, new NoiseStreamHandler(getHandshakeState().split()));
+ context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
+ } finally {
+ frame.release();
+ }
+ } else {
+ context.fireChannelRead(message);
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java
new file mode 100644
index 000000000..8ceb485e7
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java
@@ -0,0 +1,84 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import java.security.NoSuchAlgorithmException;
+import java.util.Optional;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import org.junit.jupiter.api.Test;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+
+class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
+
+ @Override
+ protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
+ final byte[] serverPublicKeySignature) {
+
+ return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature);
+ }
+
+ @Test
+ void handleCompleteHandshake()
+ throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class));
+
+ final HandshakeState clientHandshakeState =
+ new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
+
+ clientHandshakeState.start();
+
+ {
+ final byte[] ephemeralKeyMessageBytes = new byte[32];
+ clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
+
+ final BinaryWebSocketFrame ephemeralKeyMessageFrame =
+ new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
+
+ assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
+ assertEquals(0, ephemeralKeyMessageFrame.refCnt());
+ }
+
+ {
+ assertEquals(1, embeddedChannel.outboundMessages().size());
+
+ final BinaryWebSocketFrame serverStaticKeyMessageFrame =
+ (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
+
+ @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
+ new byte[serverStaticKeyMessageFrame.content().readableBytes()];
+
+ serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
+
+ final byte[] serverPublicKeySignature = new byte[64];
+
+ final int payloadLength =
+ clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
+
+ assertEquals(serverPublicKeySignature.length, payloadLength);
+
+ final byte[] serverPublicKey = new byte[32];
+ clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
+
+ assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
+ }
+
+ assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
+
+ assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class),
+ "Handshake handler should remove self from pipeline after successful handshake");
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Handshake handler should insert a Noise stream handler after successful handshake");
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java
new file mode 100644
index 000000000..749818d99
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java
@@ -0,0 +1,135 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.CipherStatePair;
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.util.internal.EmptyArrays;
+import io.netty.util.internal.ThreadLocalRandom;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import javax.crypto.AEADBadTagException;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import java.nio.charset.StandardCharsets;
+import java.security.NoSuchAlgorithmException;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class NoiseStreamHandlerTest extends AbstractLeakDetectionTest {
+
+ private CipherStatePair clientCipherStatePair;
+ private EmbeddedChannel embeddedChannel;
+
+ // We use an NN handshake for this test just because it's a little shorter and easier to set up
+ private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b";
+
+ @BeforeEach
+ void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException {
+ final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
+ final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER);
+
+ clientHandshakeState.start();
+ serverHandshakeState.start();
+
+ final byte[] clientEphemeralKeyMessage = new byte[32];
+ assertEquals(clientEphemeralKeyMessage.length,
+ clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0));
+
+ serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
+
+ // 32 bytes of key material plus a 16-byte MAC
+ final byte[] serverEphemeralKeyMessage = new byte[48];
+ assertEquals(serverEphemeralKeyMessage.length,
+ serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0));
+
+ clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
+
+ clientCipherStatePair = clientHandshakeState.split();
+ embeddedChannel = new EmbeddedChannel(new NoiseStreamHandler(serverHandshakeState.split()));
+
+ clientHandshakeState.destroy();
+ serverHandshakeState.destroy();
+ }
+
+ @Test
+ void channelRead() throws ShortBufferException, InterruptedException {
+ final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
+ final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
+ clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
+
+ final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
+ assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
+ assertEquals(0, ciphertextFrame.refCnt());
+
+ final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll();
+ assertNotNull(decryptedPlaintextBuffer);
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+
+ final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer);
+ decryptedPlaintextBuffer.release();
+
+ assertArrayEquals(plaintext, decryptedPlaintext);
+ }
+
+ @Test
+ void channelReadBadCiphertext() throws InterruptedException {
+ final byte[] bogusCiphertext = new byte[32];
+ ThreadLocalRandom.current().nextBytes(bogusCiphertext);
+
+ final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
+ final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
+
+ assertEquals(0, ciphertextFrame.refCnt());
+ assertFalse(readCiphertextFuture.isSuccess());
+ assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ }
+
+ @Test
+ void channelReadUnexpectedMessageType() throws InterruptedException {
+ final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
+
+ assertFalse(readFuture.isSuccess());
+ assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ }
+
+ @Test
+ void write() throws InterruptedException, ShortBufferException, BadPaddingException {
+ final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
+ final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
+
+ final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
+ assertTrue(writePlaintextFuture.await().isSuccess());
+ assertEquals(0, plaintextBuffer.refCnt());
+
+ final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
+ assertNotNull(ciphertextFrame);
+ assertTrue(embeddedChannel.outboundMessages().isEmpty());
+
+ final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
+ ciphertextFrame.release();
+
+ final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
+ clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
+
+ assertArrayEquals(plaintext, decryptedPlaintext);
+ }
+
+ @Test
+ void writeUnexpectedMessageType() throws InterruptedException {
+ final Object unexpectedMessaged = new Object();
+
+ final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
+ assertTrue(writeFuture.await().isSuccess());
+
+ assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
+ assertTrue(embeddedChannel.outboundMessages().isEmpty());
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java
new file mode 100644
index 000000000..1966cee12
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java
@@ -0,0 +1,89 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import java.nio.ByteBuffer;
+import java.util.Optional;
+import java.util.UUID;
+import javax.crypto.ShortBufferException;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+
+class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler {
+
+ private final ECKeyPair ecKeyPair;
+
+ private final UUID accountIdentifier;
+ private final byte deviceId;
+
+ private boolean receivedServerStaticKeyMessage = false;
+
+ NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair,
+ final ECPublicKey rootPublicKey,
+ final UUID accountIdentifier,
+ final byte deviceId) {
+
+ super(rootPublicKey);
+
+ this.ecKeyPair = ecKeyPair;
+
+ this.accountIdentifier = accountIdentifier;
+ this.deviceId = deviceId;
+ }
+
+ @Override
+ protected String getNoiseProtocolName() {
+ return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME;
+ }
+
+ @Override
+ protected void startHandshake() {
+ final HandshakeState handshakeState = getHandshakeState();
+
+ // Noise-java derives the public key from the private key, so we just need to set the private key
+ handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
+ handshakeState.start();
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext context, final Object message)
+ throws NoiseHandshakeException, ShortBufferException {
+ if (message instanceof BinaryWebSocketFrame frame) {
+ try {
+ // Don't process additional messages if the handshake failed and we're just waiting to close
+ if (receivedServerStaticKeyMessage) {
+ return;
+ }
+
+ receivedServerStaticKeyMessage = true;
+ handleServerStaticKeyMessage(context, frame);
+
+ final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17);
+ clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits());
+ clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits());
+ clientIdentityBuffer.put(deviceId);
+ clientIdentityBuffer.flip();
+
+ final HandshakeState handshakeState = getHandshakeState();
+
+ // We're sending two 32-byte keys plus the client identity payload
+ final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()];
+ handshakeState.writeMessage(
+ staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining());
+
+ context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage)))
+ .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
+
+ context.pipeline().replace(this, null, new NoiseStreamHandler(handshakeState.split()));
+ context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
+ } finally {
+ frame.release();
+ }
+ } else {
+ context.fireChannelRead(message);
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java
new file mode 100644
index 000000000..239762aa5
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java
@@ -0,0 +1,454 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.southernstorm.noise.protocol.CipherState;
+import com.southernstorm.noise.protocol.HandshakeState;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import java.nio.ByteBuffer;
+import java.security.NoSuchAlgorithmException;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ThreadLocalRandom;
+import javax.crypto.BadPaddingException;
+import javax.crypto.ShortBufferException;
+import io.netty.util.ReferenceCountUtil;
+import io.netty.util.internal.EmptyArrays;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+
+class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
+
+ private ClientPublicKeysManager clientPublicKeysManager;
+
+ @Override
+ @BeforeEach
+ void setUp() {
+ clientPublicKeysManager = mock(ClientPublicKeysManager.class);
+
+ super.setUp();
+ }
+
+ @Override
+ protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
+ final byte[] serverPublicKeySignature) {
+
+ return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature);
+ }
+
+ @Test
+ void handleCompleteHandshake()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+ sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
+ getNoiseHandshakeCompleteEvent());
+
+ assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should remove self from pipeline after successful handshake");
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Handshake handler should insert a Noise stream handler after successful handshake");
+ }
+
+ @Test
+ void handleCompleteHandshakeMissingIdentityInformation()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+
+ {
+ final byte[] clientStaticKeyMessageBytes = new byte[64];
+ final int messageLength =
+ clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0);
+
+ assertEquals(clientStaticKeyMessageBytes.length, messageLength);
+
+ final BinaryWebSocketFrame clientStaticKeyMessageFrame =
+ new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
+
+ final ChannelFuture writeClientStaticKeyMessageFuture =
+ getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
+
+ assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
+ assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
+ assertEquals(0, clientStaticKeyMessageFrame.refCnt());
+ }
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should not remove self from pipeline after failed handshake");
+
+ assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Noise stream handler should not be added to pipeline after failed handshake");
+ }
+
+ @Test
+ void handleCompleteHandshakeMalformedIdentityInformation()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+
+ {
+ final byte[] clientStaticKeyMessageBytes = new byte[96];
+ final int messageLength =
+ clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32);
+
+ assertEquals(clientStaticKeyMessageBytes.length, messageLength);
+
+ final BinaryWebSocketFrame clientStaticKeyMessageFrame =
+ new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
+
+ final ChannelFuture writeClientStaticKeyMessageFuture =
+ getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
+
+ assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
+ assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
+ assertEquals(0, clientStaticKeyMessageFrame.refCnt());
+ }
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should not remove self from pipeline after failed handshake");
+
+ assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Noise stream handler should not be added to pipeline after failed handshake");
+ }
+
+ @Test
+ void handleCompleteHandshakeUnrecognizedDevice()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+ sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should not remove self from pipeline after failed handshake");
+
+ assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Noise stream handler should not be added to pipeline after failed handshake");
+ }
+
+ @Test
+ void handleCompleteHandshakePublicKeyMismatch()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+ sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should not remove self from pipeline after failed handshake");
+
+ assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Noise stream handler should not be added to pipeline after failed handshake");
+ }
+
+ @Test
+ void handleCompleteHandshakeBufferedReads()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+ sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
+
+ final ByteBuf[] additionalMessages = new ByteBuf[4];
+ final CipherState senderState = clientHandshakeState.split().getSender();
+
+ try {
+ for (int i = 0; i < additionalMessages.length; i++) {
+ final byte[] contentBytes = new byte[32];
+ ThreadLocalRandom.current().nextBytes(contentBytes);
+
+ // Copy the "plaintext" portion of the content bytes for future assertions
+ additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
+
+ // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
+ // tag
+ senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
+
+ assertTrue(
+ embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await()
+ .isSuccess());
+ }
+
+ findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
+ getNoiseHandshakeCompleteEvent());
+
+ assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should remove self from pipeline after successful handshake");
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Handshake handler should insert a Noise stream handler after successful handshake");
+
+ for (final ByteBuf additionalMessage : additionalMessages) {
+ assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(),
+ "Buffered message should pass through pipeline after successful handshake");
+ }
+ } finally {
+ for (final ByteBuf additionalMessage : additionalMessages) {
+ additionalMessage.release();
+ }
+ }
+ }
+
+ @Test
+ void handleCompleteHandshakeFailureBufferedReads()
+ throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
+
+ final UUID accountIdentifier = UUID.randomUUID();
+ final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
+ final ECKeyPair clientKeyPair = Curve.generateKeyPair();
+
+ final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>();
+
+ when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
+
+ final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
+ sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
+
+ final ByteBuf[] additionalMessages = new ByteBuf[4];
+ final CipherState senderState = clientHandshakeState.split().getSender();
+
+ try {
+ for (int i = 0; i < additionalMessages.length; i++) {
+ final byte[] contentBytes = new byte[32];
+ ThreadLocalRandom.current().nextBytes(contentBytes);
+
+ // Copy the "plaintext" portion of the content bytes for future assertions
+ additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
+
+ // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
+ // tag
+ senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
+
+ assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess());
+ }
+
+ findPublicKeyFuture.complete(Optional.empty());
+
+ // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
+ // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
+ // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
+ // and issue a "handshake complete" event.
+ embeddedChannel.runPendingTasks();
+
+ assertNull(getNoiseHandshakeCompleteEvent());
+
+ assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
+ "Handshake handler should not remove self from pipeline after failed handshake");
+
+ assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
+ "Noise stream handler should not be added to pipeline after failed handshake");
+
+ assertTrue(embeddedChannel.inboundMessages().isEmpty(),
+ "Buffered messages should not pass through pipeline after failed handshake");
+ } finally {
+ for (final ByteBuf additionalMessage : additionalMessages) {
+ additionalMessage.release();
+ }
+ }
+ }
+
+ private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair)
+ throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException {
+
+ final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
+
+ final HandshakeState clientHandshakeState =
+ new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
+
+ clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
+ clientHandshakeState.start();
+
+ {
+ final byte[] ephemeralKeyMessageBytes = new byte[32];
+ clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
+
+ final BinaryWebSocketFrame ephemeralKeyMessageFrame =
+ new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
+
+ assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
+ assertEquals(0, ephemeralKeyMessageFrame.refCnt());
+ }
+
+ {
+ assertEquals(1, embeddedChannel.outboundMessages().size());
+
+ final BinaryWebSocketFrame serverStaticKeyMessageFrame =
+ (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
+
+ @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
+ new byte[serverStaticKeyMessageFrame.content().readableBytes()];
+
+ serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
+
+ final byte[] serverPublicKeySignature = new byte[64];
+
+ final int payloadLength =
+ clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
+
+ assertEquals(serverPublicKeySignature.length, payloadLength);
+
+ final byte[] serverPublicKey = new byte[32];
+ clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
+
+ assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
+ }
+
+ return clientHandshakeState;
+ }
+
+ private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId)
+ throws ShortBufferException, InterruptedException {
+
+ final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
+ clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
+ clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
+ clientIdentityPayloadBuffer.put(deviceId);
+ clientIdentityPayloadBuffer.flip();
+
+ final byte[] clientStaticKeyMessageBytes = new byte[81];
+ final int messageLength =
+ handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining());
+
+ assertEquals(clientStaticKeyMessageBytes.length, messageLength);
+
+ final BinaryWebSocketFrame clientStaticKeyMessageFrame =
+ new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
+
+ assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess());
+ assertEquals(0, clientStaticKeyMessageFrame.refCnt());
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java
new file mode 100644
index 000000000..682b22d90
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java
@@ -0,0 +1,24 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+
+class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter {
+
+ private final WebSocketCloseListener webSocketCloseListener;
+
+ OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
+ this.webSocketCloseListener = webSocketCloseListener;
+ }
+
+ @Override
+ public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
+ if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
+ webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode());
+ }
+
+ super.write(context, message, promise);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java
new file mode 100644
index 000000000..c16a6468d
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java
@@ -0,0 +1,72 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
+import io.netty.handler.codec.http.websocketx.WebSocketFrame;
+import io.netty.util.ReferenceCountUtil;
+import java.util.List;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
+
+ private EmbeddedChannel embeddedChannel;
+
+ @BeforeEach
+ void setUp() {
+ embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler());
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void allowWebSocketFrame(final WebSocketFrame frame) {
+ embeddedChannel.writeOneInbound(frame);
+
+ try {
+ assertEquals(frame, embeddedChannel.inboundMessages().poll());
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ assertEquals(1, frame.refCnt());
+ } finally {
+ frame.release();
+ }
+ }
+
+ private static List allowWebSocketFrame() {
+ return List.of(
+ new BinaryWebSocketFrame(),
+ new CloseWebSocketFrame(),
+ new ContinuationWebSocketFrame(),
+ new PingWebSocketFrame(),
+ new PongWebSocketFrame());
+ }
+
+ @Test
+ void rejectTextFrame() {
+ final TextWebSocketFrame textFrame = new TextWebSocketFrame();
+ embeddedChannel.writeOneInbound(textFrame);
+
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ assertEquals(0, textFrame.refCnt());
+ }
+
+ @Test
+ void rejectNonWebSocketFrame() {
+ final ByteBuf bytes = Unpooled.buffer(0);
+ embeddedChannel.writeOneInbound(bytes);
+
+ assertTrue(embeddedChannel.inboundMessages().isEmpty());
+ assertEquals(0, bytes.refCnt());
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java
new file mode 100644
index 000000000..38253eca7
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java
@@ -0,0 +1,18 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+interface WebSocketCloseListener {
+
+ WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() {
+ @Override
+ public void handleWebSocketClosedByClient(final int statusCode) {
+ }
+
+ @Override
+ public void handleWebSocketClosedByServer(final int statusCode) {
+ }
+ };
+
+ void handleWebSocketClosedByClient(int statusCode);
+
+ void handleWebSocketClosedByServer(int statusCode);
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java
new file mode 100644
index 000000000..fb31da1c6
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java
@@ -0,0 +1,70 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.local.LocalServerChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import java.net.SocketAddress;
+import java.net.URI;
+import java.security.cert.X509Certificate;
+import java.util.UUID;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+import javax.annotation.Nullable;
+
+class WebSocketNoiseTunnelClient implements AutoCloseable {
+
+ private final ServerBootstrap serverBootstrap;
+ private Channel serverChannel;
+
+ static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
+ static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
+
+ public WebSocketNoiseTunnelClient(final SocketAddress remoteServerAddress,
+ final URI websocketUri,
+ final boolean authenticated,
+ final ECKeyPair ecKeyPair,
+ final ECPublicKey rootPublicKey,
+ @Nullable final UUID accountIdentifier,
+ final byte deviceId,
+ final X509Certificate trustedServerCertificate,
+ final NioEventLoopGroup eventLoopGroup,
+ final WebSocketCloseListener webSocketCloseListener) {
+
+ this.serverBootstrap = new ServerBootstrap()
+ .localAddress(new LocalAddress("websocket-noise-tunnel-client"))
+ .channel(LocalServerChannel.class)
+ .group(eventLoopGroup)
+ .childHandler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(final LocalChannel localChannel) {
+ localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate,
+ websocketUri,
+ authenticated,
+ ecKeyPair,
+ rootPublicKey,
+ accountIdentifier,
+ deviceId,
+ remoteServerAddress,
+ webSocketCloseListener));
+ }
+ });
+ }
+
+ LocalAddress getLocalAddress() {
+ return (LocalAddress) serverChannel.localAddress();
+ }
+
+ WebSocketNoiseTunnelClient start() throws InterruptedException {
+ serverChannel = serverBootstrap.bind().await().channel();
+ return this;
+ }
+
+ @Override
+ public void close() throws InterruptedException {
+ serverChannel.close().await();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java
new file mode 100644
index 000000000..431edec62
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java
@@ -0,0 +1,486 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyByte;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.grpc.ManagedChannel;
+import io.grpc.ServerBuilder;
+import io.grpc.Status;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.channel.DefaultEventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.nio.charset.StandardCharsets;
+import java.security.KeyFactory;
+import java.security.KeyStore;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.SecureRandom;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.security.spec.InvalidKeySpecException;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.util.Base64;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.TrustManagerFactory;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.signal.chat.rpc.AuthenticationTypeGrpc;
+import org.signal.chat.rpc.GetAuthenticatedRequest;
+import org.signal.chat.rpc.GetAuthenticatedResponse;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+
+class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
+
+ private static NioEventLoopGroup nioEventLoopGroup;
+ private static DefaultEventLoopGroup defaultEventLoopGroup;
+ private static ExecutorService delegatedTaskExecutor;
+
+ private static X509Certificate serverTlsCertificate;
+
+ private ClientPublicKeysManager clientPublicKeysManager;
+
+ private ECKeyPair rootKeyPair;
+ private ECKeyPair clientKeyPair;
+
+ private ManagedLocalGrpcServer authenticatedGrpcServer;
+ private ManagedLocalGrpcServer anonymousGrpcServer;
+
+ private WebsocketNoiseTunnelServer websocketNoiseTunnelServer;
+
+ private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
+ private static final byte DEVICE_ID = Device.PRIMARY_ID;
+
+ // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
+ // They were generated with:
+ //
+ // ```shell
+ // openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
+ // ```
+ private static final String SERVER_CERTIFICATE = """
+ -----BEGIN CERTIFICATE-----
+ MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
+ FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
+ MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
+ IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
+ jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
+ SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
+ GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
+ SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
+ XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
+ iOr9sHiO8Rn2u0xRKgU5Ig==
+ -----END CERTIFICATE-----
+ """;
+
+ // BEGIN/END PRIVATE KEY header/footer removed for easier parsing
+ private static final String SERVER_PRIVATE_KEY = """
+ MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
+ kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
+ PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
+ O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
+ """;
+
+ @BeforeAll
+ static void setUpBeforeAll() throws CertificateException {
+ nioEventLoopGroup = new NioEventLoopGroup();
+ defaultEventLoopGroup = new DefaultEventLoopGroup();
+ delegatedTaskExecutor = Executors.newSingleThreadExecutor();
+
+ final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
+ serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
+ new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
+ }
+
+ @BeforeEach
+ void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException {
+
+ final PrivateKey serverTlsPrivateKey;
+ {
+ final KeyFactory keyFactory = KeyFactory.getInstance("EC");
+ serverTlsPrivateKey =
+ keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
+ }
+
+ rootKeyPair = Curve.generateKeyPair();
+ clientKeyPair = Curve.generateKeyPair();
+ final ECKeyPair serverKeyPair = Curve.generateKeyPair();
+
+ clientPublicKeysManager = mock(ClientPublicKeysManager.class);
+ when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
+ .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
+
+ final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
+ final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
+
+ authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
+ @Override
+ protected void configureServer(final ServerBuilder> serverBuilder) {
+ serverBuilder.addService(new AuthenticationTypeService(true));
+ }
+ };
+
+ authenticatedGrpcServer.start();
+
+ anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
+ @Override
+ protected void configureServer(final ServerBuilder> serverBuilder) {
+ serverBuilder.addService(new AuthenticationTypeService(false));
+ }
+ };
+
+ anonymousGrpcServer.start();
+
+ websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
+ new X509Certificate[] { serverTlsCertificate },
+ serverTlsPrivateKey,
+ nioEventLoopGroup,
+ delegatedTaskExecutor,
+ clientPublicKeysManager,
+ serverKeyPair,
+ rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
+ authenticatedGrpcServerAddress,
+ anonymousGrpcServerAddress);
+
+ websocketNoiseTunnelServer.start();
+ }
+
+ @AfterEach
+ void tearDown() throws InterruptedException {
+ websocketNoiseTunnelServer.stop();
+ authenticatedGrpcServer.stop();
+ anonymousGrpcServer.stop();
+ }
+
+ @AfterAll
+ static void tearDownAfterAll() throws InterruptedException {
+ nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
+ defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
+
+ delegatedTaskExecutor.shutdown();
+ //noinspection ResultOfMethodCallIgnored
+ delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
+ }
+
+ @Test
+ void connectAuthenticated() throws InterruptedException {
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient()) {
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
+
+ assertTrue(response.getAuthenticated());
+ } finally {
+ channel.shutdown();
+ }
+ }
+ }
+
+ @Test
+ void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+
+ // Try to verify the server's public key with something other than the key with which it was signed
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
+ buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
+
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
+ }
+
+ @Test
+ void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+ when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
+
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
+ buildAndStartAuthenticatedClient(webSocketCloseListener)) {
+
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
+ }
+
+ @Test
+ void connectAuthenticatedUnrecognizedDevice() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+ when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
+ buildAndStartAuthenticatedClient(webSocketCloseListener)) {
+
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
+ }
+
+ @Test
+ void connectAuthenticatedToAnonymousService() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
+ websocketNoiseTunnelServer.getLocalAddress(),
+ URI.create("wss://localhost/anonymous"),
+ true,
+ clientKeyPair,
+ rootKeyPair.getPublicKey(),
+ ACCOUNT_IDENTIFIER,
+ DEVICE_ID,
+ serverTlsCertificate,
+ nioEventLoopGroup,
+ webSocketCloseListener)
+ .start()) {
+
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
+ }
+
+ @Test
+ void connectAnonymous() throws InterruptedException {
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAnonymousClient()) {
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
+
+ assertFalse(response.getAuthenticated());
+ } finally {
+ channel.shutdown();
+ }
+ }
+ }
+
+ @Test
+ void connectAnonymousBadServerKeySignature() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+ // Try to verify the server's public key with something other than the key with which it was signed
+ try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
+ buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
+
+ final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
+ }
+
+ @Test
+ void connectAnonymousToAuthenticatedService() throws InterruptedException {
+ final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
+
+ try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
+ websocketNoiseTunnelServer.getLocalAddress(),
+ URI.create("wss://localhost/authenticated"),
+ false,
+ null,
+ rootKeyPair.getPublicKey(),
+ null,
+ (byte) 0,
+ serverTlsCertificate,
+ nioEventLoopGroup,
+ webSocketCloseListener)
+ .start()) {
+
+ final ManagedChannel channel = buildManagedChannel(websocketNoiseTunnelClient.getLocalAddress());
+
+ try {
+ //noinspection ResultOfMethodCallIgnored
+ GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
+ () -> AuthenticationTypeGrpc.newBlockingStub(channel)
+ .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
+ } finally {
+ channel.shutdown();
+ }
+ }
+
+ verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
+ }
+
+ private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
+ return NettyChannelBuilder.forAddress(localAddress)
+ .channelType(LocalChannel.class)
+ .eventLoopGroup(defaultEventLoopGroup)
+ .usePlaintext()
+ .build();
+ }
+
+ @Test
+ void rejectIllegalRequests() throws Exception {
+
+ final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
+ keyStore.load(null, null);
+ keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
+
+ final TrustManagerFactory trustManagerFactory =
+ TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+
+ trustManagerFactory.init(keyStore);
+
+ final SSLContext sslContext = SSLContext.getInstance("TLS");
+ sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
+
+ final URI authenticatedUri =
+ new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
+
+ final URI incorrectUri =
+ new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
+
+ try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
+ assertEquals(405, httpClient.send(HttpRequest.newBuilder()
+ .uri(authenticatedUri)
+ .PUT(HttpRequest.BodyPublishers.ofString("test"))
+ .build(),
+ HttpResponse.BodyHandlers.ofString()).statusCode(),
+ "Non-GET requests should not be allowed");
+
+ assertEquals(426, httpClient.send(HttpRequest.newBuilder()
+ .GET()
+ .uri(authenticatedUri)
+ .build(),
+ HttpResponse.BodyHandlers.ofString()).statusCode(),
+ "GET requests without upgrade headers should not be allowed");
+
+ assertEquals(404, httpClient.send(HttpRequest.newBuilder()
+ .GET()
+ .uri(incorrectUri)
+ .build(),
+ HttpResponse.BodyHandlers.ofString()).statusCode(),
+ "GET requests to unrecognized URIs should not be allowed");
+ }
+ }
+
+ private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
+ return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
+ }
+
+ private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
+ throws InterruptedException {
+
+ return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey());
+ }
+
+ private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
+ final ECPublicKey rootPublicKey) throws InterruptedException {
+
+ return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
+ WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
+ true,
+ clientKeyPair,
+ rootPublicKey,
+ ACCOUNT_IDENTIFIER,
+ DEVICE_ID,
+ serverTlsCertificate,
+ nioEventLoopGroup,
+ webSocketCloseListener)
+ .start();
+ }
+
+ private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
+ return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey());
+ }
+
+ private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
+ final ECPublicKey rootPublicKey) throws InterruptedException {
+
+ return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
+ WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
+ false,
+ null,
+ rootPublicKey,
+ null,
+ (byte) 0,
+ serverTlsCertificate,
+ nioEventLoopGroup,
+ webSocketCloseListener)
+ .start();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java
new file mode 100644
index 000000000..0f27e5efb
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java
@@ -0,0 +1,104 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.DefaultFullHttpRequest;
+import io.netty.handler.codec.http.DefaultHttpHeaders;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.FullHttpResponse;
+import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpHeaderValues;
+import io.netty.handler.codec.http.HttpHeaders;
+import io.netty.handler.codec.http.HttpMethod;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.handler.codec.http.HttpVersion;
+import io.netty.util.ReferenceCountUtil;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
+
+ private EmbeddedChannel embeddedChannel;
+
+ private static final String AUTHENTICATED_PATH = "/authenticated";
+ private static final String ANONYMOUS_PATH = "/anonymous";
+
+ @BeforeEach
+ void setUp() {
+ embeddedChannel = new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH));
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
+ void handleValidRequest(final String path) {
+ final FullHttpRequest request = buildRequest(HttpMethod.GET, path,
+ new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
+
+ try {
+ embeddedChannel.writeOneInbound(request);
+
+ assertEquals(1, request.refCnt());
+ assertEquals(1, embeddedChannel.inboundMessages().size());
+ assertEquals(request, embeddedChannel.inboundMessages().poll());
+ } finally {
+ request.release();
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
+ void handleUpgradeRequired(final String path) {
+ final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders());
+
+ embeddedChannel.writeOneInbound(request);
+
+ assertEquals(0, request.refCnt());
+ assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED);
+ }
+
+ @Test
+ void handleBadPath() {
+ final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect",
+ new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
+
+ embeddedChannel.writeOneInbound(request);
+
+ assertEquals(0, request.refCnt());
+ assertHttpResponse(HttpResponseStatus.NOT_FOUND);
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
+ void handleMethodNotAllowed(final String path) {
+ final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path,
+ new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
+
+ embeddedChannel.writeOneInbound(request);
+
+ assertEquals(0, request.refCnt());
+ assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED);
+ }
+
+ private void assertHttpResponse(final HttpResponseStatus expectedStatus) {
+ assertEquals(1, embeddedChannel.outboundMessages().size());
+
+ final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll());
+
+ //noinspection DataFlowIssue
+ assertEquals(expectedStatus, response.status());
+ }
+
+ private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) {
+ return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
+ method,
+ path,
+ Unpooled.buffer(0),
+ headers,
+ new DefaultHttpHeaders(true));
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java
new file mode 100644
index 000000000..98f901b28
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java
@@ -0,0 +1,91 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.DefaultHttpHeaders;
+import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.mock;
+
+class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
+
+ private UserEventRecordingHandler userEventRecordingHandler;
+ private EmbeddedChannel embeddedChannel;
+
+ private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
+
+ private final List receivedEvents = new ArrayList<>();
+
+ @Override
+ public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
+ receivedEvents.add(event);
+ }
+
+ public List getReceivedEvents() {
+ return receivedEvents;
+ }
+ }
+
+ @BeforeEach
+ void setUp() {
+ userEventRecordingHandler = new UserEventRecordingHandler();
+
+ embeddedChannel = new EmbeddedChannel(
+ new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
+ userEventRecordingHandler);
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void handleWebSocketHandshakeComplete(final String uri, final Class extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
+ final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
+ new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
+
+ embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
+
+ assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
+ assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
+
+ assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
+ }
+
+ private static List handleWebSocketHandshakeComplete() {
+ return List.of(
+ Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
+ Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
+ }
+
+ @Test
+ void handleWebSocketHandshakeCompleteUnexpectedPath() {
+ final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
+ new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
+
+ embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
+
+ assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
+ assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
+ }
+
+ @Test
+ void handleUnrecognizedEvent() {
+ final Object unrecognizedEvent = new Object();
+
+ embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
+ assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
+ }
+}
diff --git a/service/src/test/proto/authentication_type_service.proto b/service/src/test/proto/authentication_type_service.proto
new file mode 100644
index 000000000..53a60a731
--- /dev/null
+++ b/service/src/test/proto/authentication_type_service.proto
@@ -0,0 +1,22 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+syntax = "proto3";
+
+option java_multiple_files = true;
+
+package org.signal.chat.rpc;
+
+// A simple test service that identifies its authentication type to callers
+service AuthenticationType {
+ rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {}
+}
+
+message GetAuthenticatedRequest {
+}
+
+message GetAuthenticatedResponse {
+ bool authenticated = 1;
+}