From a5774bf6ff1a0edf27e953d1f01aab3ec4079dc2 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:42:42 -0500 Subject: [PATCH] Introduce a (dormant) Noise/WebSocket for future client/server communication --- pom.xml | 5 + service/config/sample.yml | 2 - service/pom.xml | 8 +- .../WhisperServerConfiguration.java | 9 - .../textsecuregcm/WhisperServerService.java | 89 +++- .../grpc/GrpcServerManagedWrapper.java | 35 -- .../net/AbstractNoiseHandshakeHandler.java | 124 +++++ .../net/ApplicationWebSocketCloseReason.java | 24 + .../net/ClientAuthenticationException.java | 7 + .../textsecuregcm/grpc/net/ErrorHandler.java | 59 +++ .../EstablishLocalGrpcConnectionHandler.java | 96 ++++ .../net/ManagedDefaultEventLoopGroup.java | 16 + .../grpc/net/ManagedLocalGrpcServer.java | 49 ++ .../grpc/net/ManagedNioEventLoopGroup.java | 16 + .../grpc/net/NoiseHandshakeCompleteEvent.java | 13 + .../grpc/net/NoiseHandshakeException.java | 12 + .../grpc/net/NoiseNXHandshakeHandler.java | 40 ++ .../grpc/net/NoiseStreamHandler.java | 95 ++++ .../grpc/net/NoiseXXHandshakeHandler.java | 178 +++++++ .../textsecuregcm/grpc/net/ProxyHandler.java | 24 + .../net/RejectUnsupportedMessagesHandler.java | 35 ++ .../net/WebSocketOpeningHandshakeHandler.java | 74 +++ .../WebsocketHandshakeCompleteListener.java | 52 ++ .../grpc/net/WebsocketNoiseTunnelServer.java | 116 +++++ .../textsecuregcm/util/UUIDUtil.java | 13 +- .../grpc/net/AbstractLeakDetectionTest.java | 21 + .../grpc/net/AbstractNoiseClientHandler.java | 94 ++++ .../AbstractNoiseHandshakeHandlerTest.java | 141 +++++ .../grpc/net/AuthenticationTypeService.java | 21 + .../grpc/net/ClientErrorHandler.java | 18 + .../net/EstablishRemoteConnectionHandler.java | 141 +++++ .../InboundCloseWebSocketFrameHandler.java | 23 + .../net/NoiseNXClientHandshakeHandler.java | 47 ++ .../grpc/net/NoiseNXHandshakeHandlerTest.java | 84 +++ .../grpc/net/NoiseStreamHandlerTest.java | 135 +++++ .../net/NoiseXXClientHandshakeHandler.java | 89 ++++ .../grpc/net/NoiseXXHandshakeHandlerTest.java | 454 ++++++++++++++++ .../OutboundCloseWebSocketFrameHandler.java | 24 + .../RejectUnsupportedMessagesHandlerTest.java | 72 +++ .../grpc/net/WebSocketCloseListener.java | 18 + .../grpc/net/WebSocketNoiseTunnelClient.java | 70 +++ ...ocketNoiseTunnelServerIntegrationTest.java | 486 ++++++++++++++++++ .../WebSocketOpeningHandshakeHandlerTest.java | 104 ++++ ...ebsocketHandshakeCompleteListenerTest.java | 91 ++++ .../proto/authentication_type_service.proto | 22 + 45 files changed, 3262 insertions(+), 84 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java create mode 100644 service/src/test/proto/authentication_type_service.proto diff --git a/pom.xml b/pom.xml index 2ba7feaa9..1c5011d46 100644 --- a/pom.xml +++ b/pom.xml @@ -286,6 +286,11 @@ libsignal-server 0.39.0 + + org.signal.forks + noise-java + 0.1.0 + org.apache.logging.log4j log4j-bom diff --git a/service/config/sample.yml b/service/config/sample.yml index e7a32878f..9f6ec4d7d 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -40,8 +40,6 @@ metrics: - ^lettuce\..+$ reportOnStop: true -grpcPort: 8080 - tlsKeyStore: password: secret://tlsKeyStore.password diff --git a/service/pom.xml b/service/pom.xml index 41fe438f5..bd9f692cd 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -52,6 +52,11 @@ libsignal-server + + org.signal.forks + noise-java + + io.dropwizard dropwizard-core @@ -242,8 +247,7 @@ io.grpc - grpc-netty-shaded - runtime + grpc-netty io.grpc diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 3c60e922e..450ac7581 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -298,11 +298,6 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private TusConfiguration tus; - @Valid - @NotNull - @JsonProperty - private int grpcPort; - @Valid @NotNull @JsonProperty @@ -539,10 +534,6 @@ public class WhisperServerConfiguration extends Configuration { return tus; } - public int getGrpcPort() { - return grpcPort; - } - public ClientReleaseConfiguration getClientReleaseConfiguration() { return clientRelease; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index eb8b680d5..edacc7973 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -21,13 +21,13 @@ import io.dropwizard.core.setup.Bootstrap; import io.dropwizard.core.setup.Environment; import io.dropwizard.jetty.HttpsConnectorFactory; import io.grpc.ServerBuilder; -import io.grpc.ServerInterceptors; import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder; import io.lettuce.core.metrics.MicrometerOptions; import io.lettuce.core.resource.ClientResources; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.grpc.MetricCollectingServerInterceptor; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; +import io.netty.channel.local.LocalAddress; import java.net.http.HttpClient; import java.time.Clock; import java.time.Duration; @@ -134,13 +134,14 @@ import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService; import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor; import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsGrpcService; -import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper; import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService; import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor; +import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup; +import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; @@ -753,20 +754,67 @@ public class WhisperServerService extends Application grpcServer = ServerBuilder.forPort(config.getGrpcPort()) - .addService(ServerInterceptors.intercept(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager), basicCredentialAuthenticationInterceptor)) - .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters)) - .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters)) - .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config)) - .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor)) - .addService(new KeysAnonymousGrpcService(accountsManager, keysManager)) - .addService(new PaymentsGrpcService(currencyManager)) - .addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, - config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor)) - .addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations)); + final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup(); + + final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager); + final MetricCollectingServerInterceptor metricCollectingServerInterceptor = + new MetricCollectingServerInterceptor(Metrics.globalRegistry); + + final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor(); + final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor(); + final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor(); + + final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous"); + final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated"); + + final ManagedLocalGrpcServer anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, localEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + // Note: interceptors run in the reverse order they are added; the remote deprecation filter + // depends on the user-agent context so it has to come first here! + // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor- + serverBuilder + // TODO: specialize metrics with user-agent platform + .intercept(metricCollectingServerInterceptor) + .intercept(errorMappingInterceptor) + .intercept(acceptLanguageInterceptor) + .intercept(remoteDeprecationFilter) + .intercept(userAgentInterceptor) + .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters)) + .addService(new KeysAnonymousGrpcService(accountsManager, keysManager)) + .addService(new PaymentsGrpcService(currencyManager)) + .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config)) + .addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations)); + } + }; + + final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + // Note: interceptors run in the reverse order they are added; the remote deprecation filter + // depends on the user-agent context so it has to come first here! + // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor- + serverBuilder + // TODO: specialize metrics with user-agent platform + .intercept(metricCollectingServerInterceptor) + .intercept(errorMappingInterceptor) + .intercept(acceptLanguageInterceptor) + .intercept(remoteDeprecationFilter) + .intercept(userAgentInterceptor) + .intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager))) + .addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager)) + .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters)) + .addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters)) + .addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, + config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket())); + } + }; + + environment.lifecycle().manage(localEventLoopGroup); + environment.lifecycle().manage(anonymousGrpcServer); + environment.lifecycle().manage(authenticatedGrpcServer); final List filters = new ArrayList<>(); - final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager); filters.add(remoteDeprecationFilter); filters.add(new RemoteAddressFilter(useRemoteAddress)); @@ -776,19 +824,6 @@ public class WhisperServerService extends Application accountAuthFilter = new BasicCredentialAuthFilter.Builder() .setAuthenticator(accountAuthenticator) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java deleted file mode 100644 index 1b88c290d..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GrpcServerManagedWrapper.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import java.io.IOException; -import java.util.concurrent.TimeUnit; - -import io.dropwizard.lifecycle.Managed; -import io.grpc.Server; - -public class GrpcServerManagedWrapper implements Managed { - - private final Server server; - - public GrpcServerManagedWrapper(final Server server) { - this.server = server; - } - - @Override - public void start() throws IOException { - server.start(); - } - - @Override - public void stop() { - try { - server.shutdown().awaitTermination(5, TimeUnit.MINUTES); - } catch (InterruptedException e) { - server.shutdownNow(); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java new file mode 100644 index 000000000..41594644d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandler.java @@ -0,0 +1,124 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.internal.EmptyArrays; +import java.security.NoSuchAlgorithmException; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +/** + * An abstract base class for XX- and NX-patterned Noise responder handshake handlers. + * + * @see The Noise Protocol Framework + */ +abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter { + + private final ECKeyPair ecKeyPair; + private final byte[] publicKeySignature; + + private final HandshakeState handshakeState; + + private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32; + + /** + * Constructs a new Noise handler with the given static server keys and static public key signature. The static public + * key must be signed by a trusted root private key whose public key is known to and trusted by authenticating + * clients. + * + * @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler + * @param ecKeyPair the static key pair for this server + * @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key + */ + AbstractNoiseHandshakeHandler(final String noiseProtocolName, + final ECKeyPair ecKeyPair, + final byte[] publicKeySignature) { + + this.ecKeyPair = ecKeyPair; + this.publicKeySignature = publicKeySignature; + + try { + this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e); + } + } + + protected HandshakeState getHandshakeState() { + return handshakeState; + } + + /** + * Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's + * static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server. + * Clients must not include an additional payload with their ephemeral key message. The server's reply contains its + * static keys along with an Ed25519 signature of its public static key by a trusted root key. + * + * @param context the channel handler context for this message + * @param frame the websocket frame containing the ephemeral key message + * + * @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a + * general Noise encryption error occurred + */ + protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) + throws NoiseHandshakeException { + + if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) { + throw new NoiseHandshakeException("Unexpected ephemeral key message length"); + } + + // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is + // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key + // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private + // key), so we just set the private key here. + handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0); + handshakeState.start(); + + // The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is + // coming off the wire and so will be in a direct buffer that doesn't have a backing array. + final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content()); + frame.content().readBytes(ephemeralKeyMessage); + + try { + handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); + } catch (final ShortBufferException e) { + // This should never happen since we're checking the length of the frame up front + throw new NoiseHandshakeException("Unexpected client payload"); + } catch (final BadPaddingException e) { + // It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag" + // subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an + // ephemeral-key-only message + throw new NoiseHandshakeException("Invalid keys"); + } + + // Send our key material and public key signature back to the client; this buffer will include: + // + // - A 32-byte plaintext ephemeral key + // - A 32-byte encrypted static key + // - A 16-byte AEAD tag for the static key + // - The public key signature payload + // - A 16-byte AEAD tag for the payload + final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16]; + + try { + handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length); + + context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial))) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } catch (final ShortBufferException e) { + // This should never happen for messages of known length that we control + throw new AssertionError("Key material buffer was too short for message", e); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + handshakeState.destroy(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java new file mode 100644 index 000000000..a0ba4ee0b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ApplicationWebSocketCloseReason.java @@ -0,0 +1,24 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; + +enum ApplicationWebSocketCloseReason { + NOISE_HANDSHAKE_ERROR(4001), + CLIENT_AUTHENTICATION_ERROR(4002), + NOISE_ENCRYPTION_ERROR(4003), + REAUTHENTICATION_REQUIRED(4004); + + private final int statusCode; + + ApplicationWebSocketCloseReason(final int statusCode) { + this.statusCode = statusCode; + } + + public int getStatusCode() { + return statusCode; + } + + WebSocketCloseStatus toWebSocketCloseStatus(final String reason) { + return new WebSocketCloseStatus(statusCode, reason); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java new file mode 100644 index 000000000..f3015b7e8 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ClientAuthenticationException.java @@ -0,0 +1,7 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +/** + * Indicates that an attempt to authenticate a remote client failed for some reason. + */ +class ClientAuthenticationException extends Exception { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java new file mode 100644 index 000000000..1828ee689 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java @@ -0,0 +1,59 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import javax.crypto.BadPaddingException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a + * WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to + * identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the + * connection. + */ +class ErrorHandler extends ChannelInboundHandlerAdapter { + + private boolean websocketHandshakeComplete = false; + + private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class); + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { + if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + setWebsocketHandshakeComplete(); + } + + context.fireUserEventTriggered(event); + } + + protected void setWebsocketHandshakeComplete() { + this.websocketHandshakeComplete = true; + } + + @Override + public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { + if (websocketHandshakeComplete) { + final WebSocketCloseStatus webSocketCloseStatus = switch (cause) { + case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage()); + case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated"); + case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error"); + default -> { + log.warn("An unexpected exception reached the end of the pipeline", cause); + yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR; + } + }; + + context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus)) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } else { + // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful + // way; just close the connection instead. + context.close(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java new file mode 100644 index 000000000..37be75df0 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java @@ -0,0 +1,96 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.util.ReferenceCountUtil; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering + * any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC + * server. + */ +class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { + + private final LocalAddress authenticatedGrpcServerAddress; + private final LocalAddress anonymousGrpcServerAddress; + private final List pendingReads = new ArrayList<>(); + + private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class); + + public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress) { + + this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress; + this.anonymousGrpcServerAddress = anonymousGrpcServerAddress; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + pendingReads.add(message); + } + + @Override + public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception { + if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { + // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements + // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to + // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the + // authenticated service. + final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent() + ? authenticatedGrpcServerAddress + : anonymousGrpcServerAddress; + + new Bootstrap() + .remoteAddress(grpcServerAddress) + // TODO Set local address + .channel(LocalChannel.class) + .group(remoteChannelContext.channel().eventLoop()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final LocalChannel localChannel) { + localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel())); + } + }) + .connect() + .addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + // Close the local connection if the remote channel closes and vice versa + remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); + future.channel().closeFuture().addListener(closeFuture -> + remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART))); + + remoteChannelContext.pipeline() + .addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel())); + + // Flush any buffered reads we accumulated while waiting to open the connection + pendingReads.forEach(remoteChannelContext::fireChannelRead); + pendingReads.clear(); + + remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this); + } else { + log.warn("Failed to establish local connection to gRPC server", future.cause()); + remoteChannelContext.close(); + } + }); + } + + remoteChannelContext.fireUserEventTriggered(event); + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + pendingReads.forEach(ReferenceCountUtil::release); + pendingReads.clear(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java new file mode 100644 index 000000000..5888174c4 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java @@ -0,0 +1,16 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.dropwizard.lifecycle.Managed; +import io.netty.channel.DefaultEventLoopGroup; + +/** + * A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing + * Dropwizard to manage the lifecycle of the event loop group. + */ +public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed { + + @Override + public void stop() throws InterruptedException { + this.shutdownGracefully().await(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java new file mode 100644 index 000000000..21cd3d846 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java @@ -0,0 +1,49 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.dropwizard.lifecycle.Managed; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalServerChannel; +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** + * A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress} + * and whose lifecycle is managed by Dropwizard via the {@link Managed} interface. + */ +public abstract class ManagedLocalGrpcServer implements Managed { + + private final Server server; + + public ManagedLocalGrpcServer(final LocalAddress localAddress, + final DefaultEventLoopGroup eventLoopGroup) { + + final ServerBuilder serverBuilder = NettyServerBuilder.forAddress(localAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(eventLoopGroup) + .workerEventLoopGroup(eventLoopGroup); + + configureServer(serverBuilder); + + server = serverBuilder.build(); + } + + protected abstract void configureServer(final ServerBuilder serverBuilder); + + @Override + public void start() throws IOException { + server.start(); + } + + @Override + public void stop() { + try { + server.shutdown().awaitTermination(5, TimeUnit.MINUTES); + } catch (final InterruptedException e) { + server.shutdownNow(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java new file mode 100644 index 000000000..06d3e97db --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedNioEventLoopGroup.java @@ -0,0 +1,16 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.dropwizard.lifecycle.Managed; +import io.netty.channel.nio.NioEventLoopGroup; + +/** + * A wrapper for a Netty {@link NioEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing + * Dropwizard to manage the lifecycle of the event loop group. + */ +public class ManagedNioEventLoopGroup extends NioEventLoopGroup implements Managed { + + @Override + public void stop() throws Exception { + this.shutdownGracefully().await(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java new file mode 100644 index 000000000..5a2f1ae99 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeCompleteEvent.java @@ -0,0 +1,13 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import java.util.Optional; + +/** + * An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process. + * + * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a + * type that performs authentication + */ +record NoiseHandshakeCompleteEvent(Optional authenticatedDevice) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java new file mode 100644 index 000000000..1a11bcec6 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java @@ -0,0 +1,12 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +/** + * Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or + * a general encryption error). + */ +class NoiseHandshakeException extends Exception { + + public NoiseHandshakeException(final String message) { + super(message); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java new file mode 100644 index 000000000..8cddd32cf --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandler.java @@ -0,0 +1,40 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.util.Optional; +import io.netty.util.ReferenceCountUtil; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +/** + * A Noise NX handler handles the responder side of a Noise NX handshake. + */ +class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler { + + static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b"; + + NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) { + super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof BinaryWebSocketFrame frame) { + try { + handleEphemeralKeyMessage(context, frame); + } finally { + frame.release(); + } + + // All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider + // the handshake complete + context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); + context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseStreamHandler(getHandshakeState().split())); + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + ReferenceCountUtil.release(message); + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java new file mode 100644 index 000000000..b2776b2cc --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandler.java @@ -0,0 +1,95 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.CipherState; +import com.southernstorm.noise.protocol.CipherStatePair; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Noise stream handler manages a bidirectional Noise session after a handshake has completed. + */ +class NoiseStreamHandler extends ChannelDuplexHandler { + + private final CipherStatePair cipherStatePair; + + private static final Logger log = LoggerFactory.getLogger(NoiseStreamHandler.class); + + NoiseStreamHandler(CipherStatePair cipherStatePair) { + this.cipherStatePair = cipherStatePair; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) + throws ShortBufferException, BadPaddingException { + + if (message instanceof BinaryWebSocketFrame frame) { + try { + final CipherState cipherState = cipherStatePair.getReceiver(); + + // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. + // We'll need to copy it to a heap buffer. + final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content()); + + // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer + final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length); + + context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength)); + } finally { + frame.release(); + } + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + ReferenceCountUtil.release(message); + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } + + @Override + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { + if (message instanceof ByteBuf plaintext) { + try { + // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames + final CipherState cipherState = cipherStatePair.getSender(); + final int plaintextLength = plaintext.readableBytes(); + + // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the + // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a + // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. + final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; + plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); + + // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer + cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); + + context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise); + } finally { + plaintext.release(); + } + } else { + if (!(message instanceof WebSocketFrame)) { + // Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that + // get issued in response to exceptions) + log.warn("Unexpected object in pipeline: {}", message); + } + + context.write(message, promise); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + cipherStatePair.destroy(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java new file mode 100644 index 000000000..2ebc9ea85 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandler.java @@ -0,0 +1,178 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +/** + * A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send + * identifying information (an account identifier and device ID) as an additional payload when sending its static key + * material. It compares the static public key against the stored public key for the identified device asynchronously, + * buffering traffic from the client until the authentication check completes. + */ +class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler { + + private final ClientPublicKeysManager clientPublicKeysManager; + + private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY; + + private final List pendingInboundFrames = new ArrayList<>(); + + static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b"; + + // When the client sends its static key message, we expect: + // + // - A 32-byte encrypted static public key + // - A 16-byte AEAD tag for the static key + // - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID) + // - A 16-byte AEAD tag for the identity payload + private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81; + + private enum AuthenticationState { + GET_EPHEMERAL_KEY, + GET_STATIC_KEY, + CHECK_PUBLIC_KEY, + ERROR + } + + public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair, + final byte[] publicKeySignature) { + + super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature); + + this.clientPublicKeysManager = clientPublicKeysManager; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof BinaryWebSocketFrame frame) { + try { + switch (authenticationState) { + case GET_EPHEMERAL_KEY -> { + try { + handleEphemeralKeyMessage(context, frame); + authenticationState = AuthenticationState.GET_STATIC_KEY; + } finally { + frame.release(); + } + } + case GET_STATIC_KEY -> { + try { + handleStaticKey(context, frame); + authenticationState = AuthenticationState.CHECK_PUBLIC_KEY; + } finally { + frame.release(); + } + } + case CHECK_PUBLIC_KEY -> { + // Buffer any inbound traffic until we've finished checking the client's public key + pendingInboundFrames.add(frame); + } + case ERROR -> { + // If authentication has failed for any reason, just discard inbound traffic until the channel closes + frame.release(); + } + } + } catch (final ShortBufferException e) { + authenticationState = AuthenticationState.ERROR; + throw new NoiseHandshakeException("Unexpected payload length"); + } catch (final BadPaddingException e) { + authenticationState = AuthenticationState.ERROR; + throw new ClientAuthenticationException(); + } + } else { + // Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an + // error + ReferenceCountUtil.release(message); + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } + + private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) + throws NoiseHandshakeException, ShortBufferException, BadPaddingException { + + if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) { + throw new NoiseHandshakeException("Unexpected client static key message length"); + } + + final HandshakeState handshakeState = getHandshakeState(); + + // The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct + // buffer into a heap buffer. + final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content()); + + // The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte) + final byte[] payload = new byte[17]; + + final UUID accountIdentifier; + final byte deviceId; + + final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage, + 0, staticKeyAndClientIdentityMessage.length, payload, 0); + + if (payloadBytesRead != 17) { + throw new NoiseHandshakeException("Unexpected identity payload length"); + } + + try { + accountIdentifier = UUIDUtil.fromBytes(payload, 0); + } catch (final IllegalArgumentException e) { + throw new NoiseHandshakeException("Could not parse account identifier"); + } + + deviceId = payload[16]; + + // Verify the identity of the caller by comparing the submitted static public key against the stored public key for + // the identified device + clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId) + .whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> { + final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()]; + handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0); + + if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) { + context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent( + Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)))); + + context.pipeline().addAfter(context.name(), null, new NoiseStreamHandler(handshakeState.split())); + + // Flush any buffered reads + pendingInboundFrames.forEach(context::fireChannelRead); + pendingInboundFrames.clear(); + + context.pipeline().remove(NoiseXXHandshakeHandler.this); + } else { + // We found a key, but it doesn't match what the caller submitted + context.fireExceptionCaught(new ClientAuthenticationException()); + authenticationState = AuthenticationState.ERROR; + } + }, + () -> { + // We couldn't find a key for the identified account/device + context.fireExceptionCaught(new ClientAuthenticationException()); + authenticationState = AuthenticationState.ERROR; + }), + context.executor()); + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + super.handlerRemoved(context); + + pendingInboundFrames.forEach(BinaryWebSocketFrame::release); + pendingInboundFrames.clear(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java new file mode 100644 index 000000000..2d5effc1a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java @@ -0,0 +1,24 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +/** + * A proxy handler writes all data read from one channel to another peer channel. + */ +class ProxyHandler extends ChannelInboundHandlerAdapter { + + private final Channel peerChannel; + + public ProxyHandler(final Channel peerChannel) { + this.peerChannel = peerChannel; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + peerChannel.writeAndFlush(message) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java new file mode 100644 index 000000000..41d7436e2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandler.java @@ -0,0 +1,35 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; + +/** + * A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process. + */ +public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof final WebSocketFrame webSocketFrame) { + if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) { + try { + context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE)); + } finally { + textWebSocketFrame.release(); + } + } else { + // Allow all other types of WebSocket frames + context.fireChannelRead(webSocketFrame); + } + } else { + // Discard anything that's not a WebSocket frame + ReferenceCountUtil.release(message); + context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE)); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java new file mode 100644 index 000000000..453ccb390 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandler.java @@ -0,0 +1,74 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.util.ReferenceCountUtil; + +/** + * A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully + * rejects requests for anything other than a WebSocket connection to a known endpoint. + */ +class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter { + + private final String authenticatedPath; + private final String anonymousPath; + + WebSocketOpeningHandshakeHandler(final String authenticatedPath, final String anonymousPath) { + this.authenticatedPath = authenticatedPath; + this.anonymousPath = anonymousPath; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof FullHttpRequest request) { + boolean shouldReleaseRequest = true; + + try { + if (request.decoderResult().isSuccess()) { + if (HttpMethod.GET.equals(request.method())) { + if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) { + if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) { + // Pass the request along to the websocket handshake handler and remove ourselves from the pipeline + shouldReleaseRequest = false; + + context.fireChannelRead(request); + context.pipeline().remove(this); + } else { + closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED); + } + } else { + closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND); + } + } else { + closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED); + } + } else { + closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST); + } + } finally { + if (shouldReleaseRequest) { + request.release(); + } + } + } else { + // Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error + ReferenceCountUtil.release(message); + throw new IllegalArgumentException("Unexpected message in pipeline: " + message); + } + } + + private static void closeConnectionWithStatus(final ChannelHandlerContext context, + final FullHttpRequest request, + final HttpResponseStatus status) { + + context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status)) + .addListener(ChannelFutureListener.CLOSE); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java new file mode 100644 index 000000000..5b743a61c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListener.java @@ -0,0 +1,52 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +/** + * A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate + * Noise handshake handler for the requested path. + */ +class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter { + + private final ClientPublicKeysManager clientPublicKeysManager; + + private final ECKeyPair ecKeyPair; + private final byte[] publicKeySignature; + + WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair, + final byte[] publicKeySignature) { + + this.clientPublicKeysManager = clientPublicKeysManager; + this.ecKeyPair = ecKeyPair; + this.publicKeySignature = publicKeySignature; + } + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { + final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) { + case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH -> + new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature); + + case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH -> + new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature); + + default -> { + // The HttpHandler should have caught all of these cases already; we'll consider it an internal error if + // something slipped through. + throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); + } + }; + + context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler); + } + + context.fireUserEventTriggered(event); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java new file mode 100644 index 000000000..f60e0f838 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketNoiseTunnelServer.java @@ -0,0 +1,116 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.google.common.annotations.VisibleForTesting; +import com.southernstorm.noise.protocol.Noise; +import io.dropwizard.lifecycle.Managed; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProtocols; +import io.netty.handler.ssl.SslProvider; +import java.net.InetSocketAddress; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.concurrent.Executor; +import javax.net.ssl.SSLException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +/** + * A WebSocket/Noise tunnel server accepts traffic from the public internet (in the form of Noise packets framed by + * binary WebSocket frames) and passes it through to a local gRPC server. + */ +public class WebsocketNoiseTunnelServer implements Managed { + + private final ServerBootstrap bootstrap; + private ServerSocketChannel channel; + + static final String AUTHENTICATED_SERVICE_PATH = "/authenticated"; + static final String ANONYMOUS_SERVICE_PATH = "/anonymous"; + + private static final Logger log = LoggerFactory.getLogger(WebsocketNoiseTunnelServer.class); + + public WebsocketNoiseTunnelServer(final int websocketPort, + final X509Certificate[] tlsCertificateChain, + final PrivateKey tlsPrivateKey, + final NioEventLoopGroup eventLoopGroup, + final Executor delegatedTaskExecutor, + final ClientPublicKeysManager clientPublicKeysManager, + final ECKeyPair ecKeyPair, + final byte[] publicKeySignature, + final LocalAddress authenticatedGrpcServerAddress, + final LocalAddress anonymousGrpcServerAddress) throws SSLException { + + final SslProvider sslProvider; + + if (OpenSsl.isAvailable()) { + log.info("Native OpenSSL provider is available; will use native provider"); + sslProvider = SslProvider.OPENSSL; + } else { + log.info("No native SSL provider available; will use JDK provider"); + sslProvider = SslProvider.JDK; + } + + final SslContext sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain) + .clientAuth(ClientAuth.NONE) + .protocols(SslProtocols.TLS_v1_3) + .sslProvider(sslProvider) + .build(); + + this.bootstrap = new ServerBootstrap() + .group(eventLoopGroup) + .channel(NioServerSocketChannel.class) + .localAddress(websocketPort) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel socketChannel) { + socketChannel.pipeline() + .addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor)) + .addLast(new HttpServerCodec()) + .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) + // The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade + // request and passed it down the pipeline + .addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH)) + .addLast(new WebSocketServerProtocolHandler("/", true)) + .addLast(new RejectUnsupportedMessagesHandler()) + // The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once + // a WebSocket handshake has been completed + .addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature)) + // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler + // once the Noise handshake has completed + .addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress)) + .addLast(new ErrorHandler()); + } + }); + } + + @VisibleForTesting + InetSocketAddress getLocalAddress() { + return channel.localAddress(); + } + + @Override + public void start() throws InterruptedException { + channel = (ServerSocketChannel) bootstrap.bind().await().channel(); + } + + @Override + public void stop() throws InterruptedException { + if (channel != null) { + channel.close().await(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java index 38f2cb34b..37cc6bbe3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.util; import com.google.protobuf.ByteString; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; -import java.util.Optional; import java.util.UUID; public final class UUIDUtil { @@ -40,6 +39,10 @@ public final class UUIDUtil { return fromByteBuffer(ByteBuffer.wrap(bytes)); } + public static UUID fromBytes(final byte[] bytes, final int offset) { + return fromByteBuffer(ByteBuffer.wrap(bytes, offset, 16)); + } + public static UUID fromByteBuffer(final ByteBuffer byteBuffer) { try { final long mostSigBits = byteBuffer.getLong(); @@ -52,12 +55,4 @@ public final class UUIDUtil { throw new IllegalArgumentException("unexpected byte array length; was less than 16"); } } - - public static Optional fromStringSafe(final String uuidString) { - try { - return Optional.of(UUID.fromString(uuidString)); - } catch (final IllegalArgumentException e) { - return Optional.empty(); - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java new file mode 100644 index 000000000..79fb2078f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java @@ -0,0 +1,21 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.util.ResourceLeakDetector; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +abstract class AbstractLeakDetectionTest { + + private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel; + + @BeforeAll + static void setLeakDetectionLevel() { + originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel(); + ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); + } + + @AfterAll + static void restoreLeakDetectionLevel() { + ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java new file mode 100644 index 000000000..63b114af5 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseClientHandler.java @@ -0,0 +1,94 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import java.security.NoSuchAlgorithmException; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter { + + private final ECPublicKey rootPublicKey; + + private final HandshakeState handshakeState; + + AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) { + this.rootPublicKey = rootPublicKey; + + try { + handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e); + } + } + + protected abstract String getNoiseProtocolName(); + + protected abstract void startHandshake(); + + protected HandshakeState getHandshakeState() { + return handshakeState; + } + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { + if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { + if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + startHandshake(); + + final byte[] ephemeralKeyMessage = new byte[32]; + handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0); + + context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage))) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } + } + + super.userEventTriggered(context, event); + } + + protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame) + throws NoiseHandshakeException { + + // The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap + // buffer so we can Noise at it. + final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes()); + final byte[] serverPublicKeySignature = new byte[64]; + + try { + frame.content().readBytes(keyMaterialBuffer); + + final int payloadBytesRead = + handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0); + + if (payloadBytesRead != 64) { + throw new NoiseHandshakeException("Unexpected signature length"); + } + } catch (final ShortBufferException e) { + throw new NoiseHandshakeException("Unexpected signature length"); + } catch (final BadPaddingException e) { + throw new NoiseHandshakeException("Invalid keys"); + } finally { + keyMaterialBuffer.release(); + } + + final byte[] serverPublicKey = new byte[32]; + handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); + + if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) { + throw new NoiseHandshakeException("Invalid server public key signature"); + } + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) throws Exception { + handshakeState.destroy(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java new file mode 100644 index 000000000..885b08645 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandshakeHandlerTest.java @@ -0,0 +1,141 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.Nullable; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest { + + private ECPublicKey rootPublicKey; + + private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler; + + private EmbeddedChannel embeddedChannel; + + private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { + + @Nullable + private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null; + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) { + handshakeCompleteEvent = noiseHandshakeCompleteEvent; + } else { + context.fireUserEventTriggered(event); + } + } + + @Nullable + public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() { + return handshakeCompleteEvent; + } + } + + @BeforeEach + void setUp() { + final ECKeyPair rootKeyPair = Curve.generateKeyPair(); + final ECKeyPair serverKeyPair = Curve.generateKeyPair(); + + rootPublicKey = rootKeyPair.getPublicKey(); + + final byte[] serverPublicKeySignature = + rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()); + + noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler(); + + embeddedChannel = + new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler); + } + + @AfterEach + void tearDown() { + embeddedChannel.close(); + } + + protected EmbeddedChannel getEmbeddedChannel() { + return embeddedChannel; + } + + protected ECPublicKey getRootPublicKey() { + return rootPublicKey; + } + + @Nullable + protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() { + return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent(); + } + + protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature); + + @Test + void handleInvalidInitialMessage() throws InterruptedException { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await(); + + assertFalse(writeFuture.isSuccess()); + assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); + assertEquals(0, content.refCnt()); + assertNull(getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException { + final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7]; + + for (int i = 0; i < frames.length; i++) { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); + + embeddedChannel.writeOneInbound(frames[i]).await(); + } + + for (final BinaryWebSocketFrame frame : frames) { + assertEquals(0, frame.refCnt()); + } + + assertNull(getNoiseHandshakeCompleteEvent()); + } + + @Test + void handleNonWebSocketBinaryFrame() throws InterruptedException { + final byte[] contentBytes = new byte[17]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + final ByteBuf message = Unpooled.wrappedBuffer(contentBytes); + + final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await(); + + assertFalse(writeFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, writeFuture.cause()); + assertEquals(0, message.refCnt()); + assertNull(getNoiseHandshakeCompleteEvent()); + + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java new file mode 100644 index 000000000..112c1f20f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AuthenticationTypeService.java @@ -0,0 +1,21 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.grpc.stub.StreamObserver; +import org.signal.chat.rpc.AuthenticationTypeGrpc; +import org.signal.chat.rpc.GetAuthenticatedRequest; +import org.signal.chat.rpc.GetAuthenticatedResponse; + +public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase { + + private final boolean authenticated; + + public AuthenticationTypeService(final boolean authenticated) { + this.authenticated = authenticated; + } + + @Override + public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver responseObserver) { + responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build()); + responseObserver.onCompleted(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java new file mode 100644 index 000000000..ae7bcc133 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ClientErrorHandler.java @@ -0,0 +1,18 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; + +class ClientErrorHandler extends ErrorHandler { + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { + if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { + if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + setWebsocketHandshakeComplete(); + } + } + + super.userEventTriggered(context, event); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java new file mode 100644 index 000000000..ec22d6af7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/EstablishRemoteConnectionHandler.java @@ -0,0 +1,141 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.Noise; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.util.ReferenceCountUtil; +import java.net.SocketAddress; +import java.net.URI; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import javax.annotation.Nullable; +import javax.net.ssl.SSLException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { + + private final X509Certificate trustedServerCertificate; + private final URI websocketUri; + private final boolean authenticated; + @Nullable private final ECKeyPair ecKeyPair; + private final ECPublicKey rootPublicKey; + @Nullable private final UUID accountIdentifier; + private final byte deviceId; + private final SocketAddress remoteServerAddress; + private final WebSocketCloseListener webSocketCloseListener; + + private final List pendingReads = new ArrayList<>(); + + private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake"; + + EstablishRemoteConnectionHandler( + final X509Certificate trustedServerCertificate, + final URI websocketUri, + final boolean authenticated, + @Nullable final ECKeyPair ecKeyPair, + final ECPublicKey rootPublicKey, + @Nullable final UUID accountIdentifier, + final byte deviceId, + final SocketAddress remoteServerAddress, + final WebSocketCloseListener webSocketCloseListener) { + + this.trustedServerCertificate = trustedServerCertificate; + this.websocketUri = websocketUri; + this.authenticated = authenticated; + this.ecKeyPair = ecKeyPair; + this.rootPublicKey = rootPublicKey; + this.accountIdentifier = accountIdentifier; + this.deviceId = deviceId; + this.remoteServerAddress = remoteServerAddress; + this.webSocketCloseListener = webSocketCloseListener; + } + + @Override + public void handlerAdded(final ChannelHandlerContext localContext) { + new Bootstrap() + .channel(NioSocketChannel.class) + .group(localContext.channel().eventLoop()) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(final SocketChannel channel) throws SSLException { + channel.pipeline() + .addLast(SslContextBuilder + .forClient() + .trustManager(trustedServerCertificate) + .build() + .newHandler(channel.alloc())) + .addLast(new HttpClientCodec()) + .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) + // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we + // want to react to them on our own, we need to catch them before they hit that handler. + .addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener)) + .addLast(new WebSocketClientProtocolHandler(websocketUri, + WebSocketVersion.V13, + null, + false, + new DefaultHttpHeaders(), + Noise.MAX_PACKET_LEN, + 10_000)) + .addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener)) + .addLast(authenticated + ? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId) + : new NoiseNXClientHandshakeHandler(rootPublicKey)) + .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event) + throws Exception { + if (event instanceof NoiseHandshakeCompleteEvent) { + remoteContext.pipeline() + .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); + + localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); + + pendingReads.forEach(localContext::fireChannelRead); + pendingReads.clear(); + + localContext.pipeline().remove(EstablishRemoteConnectionHandler.this); + } + + super.userEventTriggered(remoteContext, event); + } + }) + .addLast(new ClientErrorHandler()); + } + }) + .connect(remoteServerAddress) + .addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + // Close the local connection if the remote channel closes and vice versa + future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close()); + localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); + } else { + localContext.close(); + } + }); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) { + pendingReads.add(message); + } + + @Override + public void handlerRemoved(final ChannelHandlerContext context) { + pendingReads.forEach(ReferenceCountUtil::release); + pendingReads.clear(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java new file mode 100644 index 000000000..e946c937a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/InboundCloseWebSocketFrameHandler.java @@ -0,0 +1,23 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; + +class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter { + + private final WebSocketCloseListener webSocketCloseListener; + + public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) { + this.webSocketCloseListener = webSocketCloseListener; + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { + webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode()); + } + + super.channelRead(context, message); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java new file mode 100644 index 000000000..fc8b68e85 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXClientHandshakeHandler.java @@ -0,0 +1,47 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.util.Optional; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler { + + private boolean receivedServerStaticKeyMessage = false; + + NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) { + super(rootPublicKey); + } + + @Override + protected String getNoiseProtocolName() { + return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME; + } + + @Override + protected void startHandshake() { + getHandshakeState().start(); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { + if (message instanceof BinaryWebSocketFrame frame) { + try { + // Don't process additional messages if we're just waiting to close because the handshake failed + if (receivedServerStaticKeyMessage) { + return; + } + + receivedServerStaticKeyMessage = true; + handleServerStaticKeyMessage(context, frame); + + context.pipeline().replace(this, null, new NoiseStreamHandler(getHandshakeState().split())); + context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); + } finally { + frame.release(); + } + } else { + context.fireChannelRead(message); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java new file mode 100644 index 000000000..8ceb485e7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseNXHandshakeHandlerTest.java @@ -0,0 +1,84 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.security.NoSuchAlgorithmException; +import java.util.Optional; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.ECKeyPair; + +class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest { + + @Override + protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair, + final byte[] serverPublicKeySignature) { + + return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature); + } + + @Test + void handleCompleteHandshake() + throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class)); + + final HandshakeState clientHandshakeState = + new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); + + clientHandshakeState.start(); + + { + final byte[] ephemeralKeyMessageBytes = new byte[32]; + clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0); + + final BinaryWebSocketFrame ephemeralKeyMessageFrame = + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes)); + + assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess()); + assertEquals(0, ephemeralKeyMessageFrame.refCnt()); + } + + { + assertEquals(1, embeddedChannel.outboundMessages().size()); + + final BinaryWebSocketFrame serverStaticKeyMessageFrame = + (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + + @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = + new byte[serverStaticKeyMessageFrame.content().readableBytes()]; + + serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); + + final byte[] serverPublicKeySignature = new byte[64]; + + final int payloadLength = + clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0); + + assertEquals(serverPublicKeySignature.length, payloadLength); + + final byte[] serverPublicKey = new byte[32]; + clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); + + assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature)); + } + + assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent()); + + assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class), + "Handshake handler should remove self from pipeline after successful handshake"); + + assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Handshake handler should insert a Noise stream handler after successful handshake"); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java new file mode 100644 index 000000000..749818d99 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseStreamHandlerTest.java @@ -0,0 +1,135 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ThreadLocalRandom; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.crypto.AEADBadTagException; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; + +import static org.junit.jupiter.api.Assertions.*; + +class NoiseStreamHandlerTest extends AbstractLeakDetectionTest { + + private CipherStatePair clientCipherStatePair; + private EmbeddedChannel embeddedChannel; + + // We use an NN handshake for this test just because it's a little shorter and easier to set up + private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b"; + + @BeforeEach + void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException { + final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); + final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER); + + clientHandshakeState.start(); + serverHandshakeState.start(); + + final byte[] clientEphemeralKeyMessage = new byte[32]; + assertEquals(clientEphemeralKeyMessage.length, + clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0)); + + serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); + + // 32 bytes of key material plus a 16-byte MAC + final byte[] serverEphemeralKeyMessage = new byte[48]; + assertEquals(serverEphemeralKeyMessage.length, + serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0)); + + clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0); + + clientCipherStatePair = clientHandshakeState.split(); + embeddedChannel = new EmbeddedChannel(new NoiseStreamHandler(serverHandshakeState.split())); + + clientHandshakeState.destroy(); + serverHandshakeState.destroy(); + } + + @Test + void channelRead() throws ShortBufferException, InterruptedException { + final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); + final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()]; + clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length); + + final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext)); + assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess()); + assertEquals(0, ciphertextFrame.refCnt()); + + final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll(); + assertNotNull(decryptedPlaintextBuffer); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + + final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer); + decryptedPlaintextBuffer.release(); + + assertArrayEquals(plaintext, decryptedPlaintext); + } + + @Test + void channelReadBadCiphertext() throws InterruptedException { + final byte[] bogusCiphertext = new byte[32]; + ThreadLocalRandom.current().nextBytes(bogusCiphertext); + + final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext)); + final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await(); + + assertEquals(0, ciphertextFrame.refCnt()); + assertFalse(readCiphertextFuture.isSuccess()); + assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause()); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } + + @Test + void channelReadUnexpectedMessageType() throws InterruptedException { + final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await(); + + assertFalse(readFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, readFuture.cause()); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + } + + @Test + void write() throws InterruptedException, ShortBufferException, BadPaddingException { + final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); + final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext); + + final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); + assertTrue(writePlaintextFuture.await().isSuccess()); + assertEquals(0, plaintextBuffer.refCnt()); + + final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + assertNotNull(ciphertextFrame); + assertTrue(embeddedChannel.outboundMessages().isEmpty()); + + final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); + ciphertextFrame.release(); + + final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()]; + clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length); + + assertArrayEquals(plaintext, decryptedPlaintext); + } + + @Test + void writeUnexpectedMessageType() throws InterruptedException { + final Object unexpectedMessaged = new Object(); + + final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged); + assertTrue(writeFuture.await().isSuccess()); + + assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll()); + assertTrue(embeddedChannel.outboundMessages().isEmpty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java new file mode 100644 index 000000000..1966cee12 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXClientHandshakeHandler.java @@ -0,0 +1,89 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.UUID; +import javax.crypto.ShortBufferException; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler { + + private final ECKeyPair ecKeyPair; + + private final UUID accountIdentifier; + private final byte deviceId; + + private boolean receivedServerStaticKeyMessage = false; + + NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair, + final ECPublicKey rootPublicKey, + final UUID accountIdentifier, + final byte deviceId) { + + super(rootPublicKey); + + this.ecKeyPair = ecKeyPair; + + this.accountIdentifier = accountIdentifier; + this.deviceId = deviceId; + } + + @Override + protected String getNoiseProtocolName() { + return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME; + } + + @Override + protected void startHandshake() { + final HandshakeState handshakeState = getHandshakeState(); + + // Noise-java derives the public key from the private key, so we just need to set the private key + handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0); + handshakeState.start(); + } + + @Override + public void channelRead(final ChannelHandlerContext context, final Object message) + throws NoiseHandshakeException, ShortBufferException { + if (message instanceof BinaryWebSocketFrame frame) { + try { + // Don't process additional messages if the handshake failed and we're just waiting to close + if (receivedServerStaticKeyMessage) { + return; + } + + receivedServerStaticKeyMessage = true; + handleServerStaticKeyMessage(context, frame); + + final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17); + clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits()); + clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits()); + clientIdentityBuffer.put(deviceId); + clientIdentityBuffer.flip(); + + final HandshakeState handshakeState = getHandshakeState(); + + // We're sending two 32-byte keys plus the client identity payload + final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()]; + handshakeState.writeMessage( + staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining()); + + context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage))) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + + context.pipeline().replace(this, null, new NoiseStreamHandler(handshakeState.split())); + context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty())); + } finally { + frame.release(); + } + } else { + context.fireChannelRead(message); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java new file mode 100644 index 000000000..239762aa5 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseXXHandshakeHandlerTest.java @@ -0,0 +1,454 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.southernstorm.noise.protocol.CipherState; +import com.southernstorm.noise.protocol.HandshakeState; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ThreadLocalRandom; +import javax.crypto.BadPaddingException; +import javax.crypto.ShortBufferException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.storage.Device; + +class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest { + + private ClientPublicKeysManager clientPublicKeysManager; + + @Override + @BeforeEach + void setUp() { + clientPublicKeysManager = mock(ClientPublicKeysManager.class); + + super.setUp(); + } + + @Override + protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair, + final byte[] serverPublicKeySignature) { + + return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature); + } + + @Test + void handleCompleteHandshake() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + getNoiseHandshakeCompleteEvent()); + + assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should remove self from pipeline after successful handshake"); + + assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Handshake handler should insert a Noise stream handler after successful handshake"); + } + + @Test + void handleCompleteHandshakeMissingIdentityInformation() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + + { + final byte[] clientStaticKeyMessageBytes = new byte[64]; + final int messageLength = + clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0); + + assertEquals(clientStaticKeyMessageBytes.length, messageLength); + + final BinaryWebSocketFrame clientStaticKeyMessageFrame = + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); + + final ChannelFuture writeClientStaticKeyMessageFuture = + getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await(); + + assertFalse(writeClientStaticKeyMessageFuture.isSuccess()); + assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause()); + assertEquals(0, clientStaticKeyMessageFrame.refCnt()); + } + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakeMalformedIdentityInformation() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + + { + final byte[] clientStaticKeyMessageBytes = new byte[96]; + final int messageLength = + clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32); + + assertEquals(clientStaticKeyMessageBytes.length, messageLength); + + final BinaryWebSocketFrame clientStaticKeyMessageFrame = + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); + + final ChannelFuture writeClientStaticKeyMessageFuture = + getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await(); + + assertFalse(writeClientStaticKeyMessageFuture.isSuccess()); + assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause()); + assertEquals(0, clientStaticKeyMessageFrame.refCnt()); + } + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakeUnrecognizedDevice() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakePublicKeyMismatch() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + } + + @Test + void handleCompleteHandshakeBufferedReads() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); + + final ByteBuf[] additionalMessages = new ByteBuf[4]; + final CipherState senderState = clientHandshakeState.split().getSender(); + + try { + for (int i = 0; i < additionalMessages.length; i++) { + final byte[] contentBytes = new byte[32]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + // Copy the "plaintext" portion of the content bytes for future assertions + additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16); + + // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD + // tag + senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16); + + assertTrue( + embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await() + .isSuccess()); + } + + findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))), + getNoiseHandshakeCompleteEvent()); + + assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should remove self from pipeline after successful handshake"); + + assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Handshake handler should insert a Noise stream handler after successful handshake"); + + for (final ByteBuf additionalMessage : additionalMessages) { + assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(), + "Buffered message should pass through pipeline after successful handshake"); + } + } finally { + for (final ByteBuf additionalMessage : additionalMessages) { + additionalMessage.release(); + } + } + } + + @Test + void handleCompleteHandshakeFailureBufferedReads() + throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class)); + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID); + final ECKeyPair clientKeyPair = Curve.generateKeyPair(); + + final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); + + when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); + + final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair); + sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId); + + final ByteBuf[] additionalMessages = new ByteBuf[4]; + final CipherState senderState = clientHandshakeState.split().getSender(); + + try { + for (int i = 0; i < additionalMessages.length; i++) { + final byte[] contentBytes = new byte[32]; + ThreadLocalRandom.current().nextBytes(contentBytes); + + // Copy the "plaintext" portion of the content bytes for future assertions + additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16); + + // Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD + // tag + senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16); + + assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess()); + } + + findPublicKeyFuture.complete(Optional.empty()); + + // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the + // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same + // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback + // and issue a "handshake complete" event. + embeddedChannel.runPendingTasks(); + + assertNull(getNoiseHandshakeCompleteEvent()); + + assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class), + "Handshake handler should not remove self from pipeline after failed handshake"); + + assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class), + "Noise stream handler should not be added to pipeline after failed handshake"); + + assertTrue(embeddedChannel.inboundMessages().isEmpty(), + "Buffered messages should not pass through pipeline after failed handshake"); + } finally { + for (final ByteBuf additionalMessage : additionalMessages) { + additionalMessage.release(); + } + } + } + + private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair) + throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException { + + final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); + + final HandshakeState clientHandshakeState = + new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR); + + clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0); + clientHandshakeState.start(); + + { + final byte[] ephemeralKeyMessageBytes = new byte[32]; + clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0); + + final BinaryWebSocketFrame ephemeralKeyMessageFrame = + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes)); + + assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess()); + assertEquals(0, ephemeralKeyMessageFrame.refCnt()); + } + + { + assertEquals(1, embeddedChannel.outboundMessages().size()); + + final BinaryWebSocketFrame serverStaticKeyMessageFrame = + (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll(); + + @SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes = + new byte[serverStaticKeyMessageFrame.content().readableBytes()]; + + serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes); + + final byte[] serverPublicKeySignature = new byte[64]; + + final int payloadLength = + clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0); + + assertEquals(serverPublicKeySignature.length, payloadLength); + + final byte[] serverPublicKey = new byte[32]; + clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); + + assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature)); + } + + return clientHandshakeState; + } + + private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId) + throws ShortBufferException, InterruptedException { + + final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17); + clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits()); + clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits()); + clientIdentityPayloadBuffer.put(deviceId); + clientIdentityPayloadBuffer.flip(); + + final byte[] clientStaticKeyMessageBytes = new byte[81]; + final int messageLength = + handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining()); + + assertEquals(clientStaticKeyMessageBytes.length, messageLength); + + final BinaryWebSocketFrame clientStaticKeyMessageFrame = + new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes)); + + assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess()); + assertEquals(0, clientStaticKeyMessageFrame.refCnt()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java new file mode 100644 index 000000000..682b22d90 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseWebSocketFrameHandler.java @@ -0,0 +1,24 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; + +class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter { + + private final WebSocketCloseListener webSocketCloseListener; + + OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) { + this.webSocketCloseListener = webSocketCloseListener; + } + + @Override + public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { + if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { + webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode()); + } + + super.write(context, message, promise); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java new file mode 100644 index 000000000..c16a6468d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/RejectUnsupportedMessagesHandlerTest.java @@ -0,0 +1,72 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest { + + private EmbeddedChannel embeddedChannel; + + @BeforeEach + void setUp() { + embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler()); + } + + @ParameterizedTest + @MethodSource + void allowWebSocketFrame(final WebSocketFrame frame) { + embeddedChannel.writeOneInbound(frame); + + try { + assertEquals(frame, embeddedChannel.inboundMessages().poll()); + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + assertEquals(1, frame.refCnt()); + } finally { + frame.release(); + } + } + + private static List allowWebSocketFrame() { + return List.of( + new BinaryWebSocketFrame(), + new CloseWebSocketFrame(), + new ContinuationWebSocketFrame(), + new PingWebSocketFrame(), + new PongWebSocketFrame()); + } + + @Test + void rejectTextFrame() { + final TextWebSocketFrame textFrame = new TextWebSocketFrame(); + embeddedChannel.writeOneInbound(textFrame); + + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + assertEquals(0, textFrame.refCnt()); + } + + @Test + void rejectNonWebSocketFrame() { + final ByteBuf bytes = Unpooled.buffer(0); + embeddedChannel.writeOneInbound(bytes); + + assertTrue(embeddedChannel.inboundMessages().isEmpty()); + assertEquals(0, bytes.refCnt()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java new file mode 100644 index 000000000..38253eca7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketCloseListener.java @@ -0,0 +1,18 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +interface WebSocketCloseListener { + + WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() { + @Override + public void handleWebSocketClosedByClient(final int statusCode) { + } + + @Override + public void handleWebSocketClosedByServer(final int statusCode) { + } + }; + + void handleWebSocketClosedByClient(int statusCode); + + void handleWebSocketClosedByServer(int statusCode); +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java new file mode 100644 index 000000000..fb31da1c6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelClient.java @@ -0,0 +1,70 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import java.net.SocketAddress; +import java.net.URI; +import java.security.cert.X509Certificate; +import java.util.UUID; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import javax.annotation.Nullable; + +class WebSocketNoiseTunnelClient implements AutoCloseable { + + private final ServerBootstrap serverBootstrap; + private Channel serverChannel; + + static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated"); + static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous"); + + public WebSocketNoiseTunnelClient(final SocketAddress remoteServerAddress, + final URI websocketUri, + final boolean authenticated, + final ECKeyPair ecKeyPair, + final ECPublicKey rootPublicKey, + @Nullable final UUID accountIdentifier, + final byte deviceId, + final X509Certificate trustedServerCertificate, + final NioEventLoopGroup eventLoopGroup, + final WebSocketCloseListener webSocketCloseListener) { + + this.serverBootstrap = new ServerBootstrap() + .localAddress(new LocalAddress("websocket-noise-tunnel-client")) + .channel(LocalServerChannel.class) + .group(eventLoopGroup) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(final LocalChannel localChannel) { + localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate, + websocketUri, + authenticated, + ecKeyPair, + rootPublicKey, + accountIdentifier, + deviceId, + remoteServerAddress, + webSocketCloseListener)); + } + }); + } + + LocalAddress getLocalAddress() { + return (LocalAddress) serverChannel.localAddress(); + } + + WebSocketNoiseTunnelClient start() throws InterruptedException { + serverChannel = serverBootstrap.bind().await().channel(); + return this; + } + + @Override + public void close() throws InterruptedException { + serverChannel.close().await(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java new file mode 100644 index 000000000..431edec62 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketNoiseTunnelServerIntegrationTest.java @@ -0,0 +1,486 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.ManagedChannel; +import io.grpc.ServerBuilder; +import io.grpc.Status; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Base64; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.AuthenticationTypeGrpc; +import org.signal.chat.rpc.GetAuthenticatedRequest; +import org.signal.chat.rpc.GetAuthenticatedResponse; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; +import org.whispersystems.textsecuregcm.storage.Device; + +class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest { + + private static NioEventLoopGroup nioEventLoopGroup; + private static DefaultEventLoopGroup defaultEventLoopGroup; + private static ExecutorService delegatedTaskExecutor; + + private static X509Certificate serverTlsCertificate; + + private ClientPublicKeysManager clientPublicKeysManager; + + private ECKeyPair rootKeyPair; + private ECKeyPair clientKeyPair; + + private ManagedLocalGrpcServer authenticatedGrpcServer; + private ManagedLocalGrpcServer anonymousGrpcServer; + + private WebsocketNoiseTunnelServer websocketNoiseTunnelServer; + + private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + + // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. + // They were generated with: + // + // ```shell + // openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost" + // ``` + private static final String SERVER_CERTIFICATE = """ + -----BEGIN CERTIFICATE----- + MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw + FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx + MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA + IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV + jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq + SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME + GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG + SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw + XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi + iOr9sHiO8Rn2u0xRKgU5Ig== + -----END CERTIFICATE----- + """; + + // BEGIN/END PRIVATE KEY header/footer removed for easier parsing + private static final String SERVER_PRIVATE_KEY = """ + MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj + kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd + PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA + O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo= + """; + + @BeforeAll + static void setUpBeforeAll() throws CertificateException { + nioEventLoopGroup = new NioEventLoopGroup(); + defaultEventLoopGroup = new DefaultEventLoopGroup(); + delegatedTaskExecutor = Executors.newSingleThreadExecutor(); + + final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( + new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8))); + } + + @BeforeEach + void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException { + + final PrivateKey serverTlsPrivateKey; + { + final KeyFactory keyFactory = KeyFactory.getInstance("EC"); + serverTlsPrivateKey = + keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY))); + } + + rootKeyPair = Curve.generateKeyPair(); + clientKeyPair = Curve.generateKeyPair(); + final ECKeyPair serverKeyPair = Curve.generateKeyPair(); + + clientPublicKeysManager = mock(ClientPublicKeysManager.class); + when(clientPublicKeysManager.findPublicKey(any(), anyByte())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); + + final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated"); + final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous"); + + authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + serverBuilder.addService(new AuthenticationTypeService(true)); + } + }; + + authenticatedGrpcServer.start(); + + anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { + @Override + protected void configureServer(final ServerBuilder serverBuilder) { + serverBuilder.addService(new AuthenticationTypeService(false)); + } + }; + + anonymousGrpcServer.start(); + + websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0, + new X509Certificate[] { serverTlsCertificate }, + serverTlsPrivateKey, + nioEventLoopGroup, + delegatedTaskExecutor, + clientPublicKeysManager, + serverKeyPair, + rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()), + authenticatedGrpcServerAddress, + anonymousGrpcServerAddress); + + websocketNoiseTunnelServer.start(); + } + + @AfterEach + void tearDown() throws InterruptedException { + websocketNoiseTunnelServer.stop(); + authenticatedGrpcServer.stop(); + anonymousGrpcServer.stop(); + } + + @AfterAll + static void tearDownAfterAll() throws InterruptedException { + nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); + defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); + + delegatedTaskExecutor.shutdown(); + //noinspection ResultOfMethodCallIgnored + delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); + } + + @Test + void connectAuthenticated() throws InterruptedException { + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient()) { + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); + + assertTrue(response.getAuthenticated()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void connectAuthenticatedBadServerKeySignature() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + + // Try to verify the server's public key with something other than the key with which it was signed + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = + buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + } + + @Test + void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey()))); + + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = + buildAndStartAuthenticatedClient(webSocketCloseListener)) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); + } + + @Test + void connectAuthenticatedUnrecognizedDevice() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = + buildAndStartAuthenticatedClient(webSocketCloseListener)) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode()); + } + + @Test + void connectAuthenticatedToAnonymousService() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient( + websocketNoiseTunnelServer.getLocalAddress(), + URI.create("wss://localhost/anonymous"), + true, + clientKeyPair, + rootKeyPair.getPublicKey(), + ACCOUNT_IDENTIFIER, + DEVICE_ID, + serverTlsCertificate, + nioEventLoopGroup, + webSocketCloseListener) + .start()) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + } + + @Test + void connectAnonymous() throws InterruptedException { + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAnonymousClient()) { + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build()); + + assertFalse(response.getAuthenticated()); + } finally { + channel.shutdown(); + } + } + } + + @Test + void connectAnonymousBadServerKeySignature() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + // Try to verify the server's public key with something other than the key with which it was signed + try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = + buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) { + + final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + } + + @Test + void connectAnonymousToAuthenticatedService() throws InterruptedException { + final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); + + try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient( + websocketNoiseTunnelServer.getLocalAddress(), + URI.create("wss://localhost/authenticated"), + false, + null, + rootKeyPair.getPublicKey(), + null, + (byte) 0, + serverTlsCertificate, + nioEventLoopGroup, + webSocketCloseListener) + .start()) { + + final ManagedChannel channel = buildManagedChannel(websocketNoiseTunnelClient.getLocalAddress()); + + try { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, + () -> AuthenticationTypeGrpc.newBlockingStub(channel) + .getAuthenticated(GetAuthenticatedRequest.newBuilder().build())); + } finally { + channel.shutdown(); + } + } + + verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode()); + } + + private ManagedChannel buildManagedChannel(final LocalAddress localAddress) { + return NettyChannelBuilder.forAddress(localAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(defaultEventLoopGroup) + .usePlaintext() + .build(); + } + + @Test + void rejectIllegalRequests() throws Exception { + + final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); + keyStore.setCertificateEntry("tunnel", serverTlsCertificate); + + final TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + + trustManagerFactory.init(keyStore); + + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); + + final URI authenticatedUri = + new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null); + + final URI incorrectUri = + new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null); + + try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) { + assertEquals(405, httpClient.send(HttpRequest.newBuilder() + .uri(authenticatedUri) + .PUT(HttpRequest.BodyPublishers.ofString("test")) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "Non-GET requests should not be allowed"); + + assertEquals(426, httpClient.send(HttpRequest.newBuilder() + .GET() + .uri(authenticatedUri) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "GET requests without upgrade headers should not be allowed"); + + assertEquals(404, httpClient.send(HttpRequest.newBuilder() + .GET() + .uri(incorrectUri) + .build(), + HttpResponse.BodyHandlers.ofString()).statusCode(), + "GET requests to unrecognized URIs should not be allowed"); + } + } + + private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException { + return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER); + } + + private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener) + throws InterruptedException { + + return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey()); + } + + private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener, + final ECPublicKey rootPublicKey) throws InterruptedException { + + return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), + WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI, + true, + clientKeyPair, + rootPublicKey, + ACCOUNT_IDENTIFIER, + DEVICE_ID, + serverTlsCertificate, + nioEventLoopGroup, + webSocketCloseListener) + .start(); + } + + private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException { + return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey()); + } + + private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener, + final ECPublicKey rootPublicKey) throws InterruptedException { + + return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), + WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI, + false, + null, + rootPublicKey, + null, + (byte) 0, + serverTlsCertificate, + nioEventLoopGroup, + webSocketCloseListener) + .start(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java new file mode 100644 index 000000000..0f27e5efb --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebSocketOpeningHandshakeHandlerTest.java @@ -0,0 +1,104 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest { + + private EmbeddedChannel embeddedChannel; + + private static final String AUTHENTICATED_PATH = "/authenticated"; + private static final String ANONYMOUS_PATH = "/anonymous"; + + @BeforeEach + void setUp() { + embeddedChannel = new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH)); + } + + @ParameterizedTest + @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) + void handleValidRequest(final String path) { + final FullHttpRequest request = buildRequest(HttpMethod.GET, path, + new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); + + try { + embeddedChannel.writeOneInbound(request); + + assertEquals(1, request.refCnt()); + assertEquals(1, embeddedChannel.inboundMessages().size()); + assertEquals(request, embeddedChannel.inboundMessages().poll()); + } finally { + request.release(); + } + } + + @ParameterizedTest + @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) + void handleUpgradeRequired(final String path) { + final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders()); + + embeddedChannel.writeOneInbound(request); + + assertEquals(0, request.refCnt()); + assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED); + } + + @Test + void handleBadPath() { + final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect", + new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); + + embeddedChannel.writeOneInbound(request); + + assertEquals(0, request.refCnt()); + assertHttpResponse(HttpResponseStatus.NOT_FOUND); + } + + @ParameterizedTest + @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) + void handleMethodNotAllowed(final String path) { + final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path, + new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); + + embeddedChannel.writeOneInbound(request); + + assertEquals(0, request.refCnt()); + assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED); + } + + private void assertHttpResponse(final HttpResponseStatus expectedStatus) { + assertEquals(1, embeddedChannel.outboundMessages().size()); + + final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll()); + + //noinspection DataFlowIssue + assertEquals(expectedStatus, response.status()); + } + + private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) { + return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + method, + path, + Unpooled.buffer(0), + headers, + new DefaultHttpHeaders(true)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java new file mode 100644 index 000000000..98f901b28 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/WebsocketHandshakeCompleteListenerTest.java @@ -0,0 +1,91 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.signal.libsignal.protocol.ecc.Curve; +import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; + +class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest { + + private UserEventRecordingHandler userEventRecordingHandler; + private EmbeddedChannel embeddedChannel; + + private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter { + + private final List receivedEvents = new ArrayList<>(); + + @Override + public void userEventTriggered(final ChannelHandlerContext context, final Object event) { + receivedEvents.add(event); + } + + public List getReceivedEvents() { + return receivedEvents; + } + } + + @BeforeEach + void setUp() { + userEventRecordingHandler = new UserEventRecordingHandler(); + + embeddedChannel = new EmbeddedChannel( + new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]), + userEventRecordingHandler); + } + + @ParameterizedTest + @MethodSource + void handleWebSocketHandshakeComplete(final String uri, final Class expectedHandlerClass) { + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = + new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); + + embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class)); + assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass)); + + assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents()); + } + + private static List handleWebSocketHandshakeComplete() { + return List.of( + Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class), + Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class)); + } + + @Test + void handleWebSocketHandshakeCompleteUnexpectedPath() { + final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = + new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null); + + embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); + + assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class)); + assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException()); + } + + @Test + void handleUnrecognizedEvent() { + final Object unrecognizedEvent = new Object(); + + embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent); + assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents()); + } +} diff --git a/service/src/test/proto/authentication_type_service.proto b/service/src/test/proto/authentication_type_service.proto new file mode 100644 index 000000000..53a60a731 --- /dev/null +++ b/service/src/test/proto/authentication_type_service.proto @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +syntax = "proto3"; + +option java_multiple_files = true; + +package org.signal.chat.rpc; + +// A simple test service that identifies its authentication type to callers +service AuthenticationType { + rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {} +} + +message GetAuthenticatedRequest { +} + +message GetAuthenticatedResponse { + bool authenticated = 1; +}