Introduce a (dormant) Noise/WebSocket for future client/server communication
This commit is contained in:
parent
d2716fe5cf
commit
a5774bf6ff
5
pom.xml
5
pom.xml
|
@ -286,6 +286,11 @@
|
||||||
<artifactId>libsignal-server</artifactId>
|
<artifactId>libsignal-server</artifactId>
|
||||||
<version>0.39.0</version>
|
<version>0.39.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.signal.forks</groupId>
|
||||||
|
<artifactId>noise-java</artifactId>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.logging.log4j</groupId>
|
<groupId>org.apache.logging.log4j</groupId>
|
||||||
<artifactId>log4j-bom</artifactId>
|
<artifactId>log4j-bom</artifactId>
|
||||||
|
|
|
@ -40,8 +40,6 @@ metrics:
|
||||||
- ^lettuce\..+$
|
- ^lettuce\..+$
|
||||||
reportOnStop: true
|
reportOnStop: true
|
||||||
|
|
||||||
grpcPort: 8080
|
|
||||||
|
|
||||||
tlsKeyStore:
|
tlsKeyStore:
|
||||||
password: secret://tlsKeyStore.password
|
password: secret://tlsKeyStore.password
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,11 @@
|
||||||
<artifactId>libsignal-server</artifactId>
|
<artifactId>libsignal-server</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.signal.forks</groupId>
|
||||||
|
<artifactId>noise-java</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>io.dropwizard</groupId>
|
<groupId>io.dropwizard</groupId>
|
||||||
<artifactId>dropwizard-core</artifactId>
|
<artifactId>dropwizard-core</artifactId>
|
||||||
|
@ -242,8 +247,7 @@
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>io.grpc</groupId>
|
<groupId>io.grpc</groupId>
|
||||||
<artifactId>grpc-netty-shaded</artifactId>
|
<artifactId>grpc-netty</artifactId>
|
||||||
<scope>runtime</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>io.grpc</groupId>
|
<groupId>io.grpc</groupId>
|
||||||
|
|
|
@ -298,11 +298,6 @@ public class WhisperServerConfiguration extends Configuration {
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private TusConfiguration tus;
|
private TusConfiguration tus;
|
||||||
|
|
||||||
@Valid
|
|
||||||
@NotNull
|
|
||||||
@JsonProperty
|
|
||||||
private int grpcPort;
|
|
||||||
|
|
||||||
@Valid
|
@Valid
|
||||||
@NotNull
|
@NotNull
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
|
@ -539,10 +534,6 @@ public class WhisperServerConfiguration extends Configuration {
|
||||||
return tus;
|
return tus;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getGrpcPort() {
|
|
||||||
return grpcPort;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ClientReleaseConfiguration getClientReleaseConfiguration() {
|
public ClientReleaseConfiguration getClientReleaseConfiguration() {
|
||||||
return clientRelease;
|
return clientRelease;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,13 +21,13 @@ import io.dropwizard.core.setup.Bootstrap;
|
||||||
import io.dropwizard.core.setup.Environment;
|
import io.dropwizard.core.setup.Environment;
|
||||||
import io.dropwizard.jetty.HttpsConnectorFactory;
|
import io.dropwizard.jetty.HttpsConnectorFactory;
|
||||||
import io.grpc.ServerBuilder;
|
import io.grpc.ServerBuilder;
|
||||||
import io.grpc.ServerInterceptors;
|
|
||||||
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
|
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
|
||||||
import io.lettuce.core.metrics.MicrometerOptions;
|
import io.lettuce.core.metrics.MicrometerOptions;
|
||||||
import io.lettuce.core.resource.ClientResources;
|
import io.lettuce.core.resource.ClientResources;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.binder.grpc.MetricCollectingServerInterceptor;
|
import io.micrometer.core.instrument.binder.grpc.MetricCollectingServerInterceptor;
|
||||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
import java.net.http.HttpClient;
|
import java.net.http.HttpClient;
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
@ -134,13 +134,14 @@ import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
|
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsAnonymousGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsAnonymousGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.ExternalServiceCredentialsGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
|
|
||||||
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
|
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
|
||||||
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
|
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||||
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
|
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
|
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
|
||||||
|
@ -753,20 +754,67 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor =
|
final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor =
|
||||||
new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager));
|
new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager));
|
||||||
|
|
||||||
final ServerBuilder<?> grpcServer = ServerBuilder.forPort(config.getGrpcPort())
|
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
|
||||||
.addService(ServerInterceptors.intercept(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager), basicCredentialAuthenticationInterceptor))
|
|
||||||
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
|
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
|
||||||
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
|
final MetricCollectingServerInterceptor metricCollectingServerInterceptor =
|
||||||
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
|
new MetricCollectingServerInterceptor(Metrics.globalRegistry);
|
||||||
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor))
|
|
||||||
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
|
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
|
||||||
.addService(new PaymentsGrpcService(currencyManager))
|
final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor();
|
||||||
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
|
final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor();
|
||||||
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
|
|
||||||
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
|
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
|
||||||
|
|
||||||
|
final ManagedLocalGrpcServer anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, localEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||||
|
// depends on the user-agent context so it has to come first here!
|
||||||
|
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
||||||
|
serverBuilder
|
||||||
|
// TODO: specialize metrics with user-agent platform
|
||||||
|
.intercept(metricCollectingServerInterceptor)
|
||||||
|
.intercept(errorMappingInterceptor)
|
||||||
|
.intercept(acceptLanguageInterceptor)
|
||||||
|
.intercept(remoteDeprecationFilter)
|
||||||
|
.intercept(userAgentInterceptor)
|
||||||
|
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
|
||||||
|
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
|
||||||
|
.addService(new PaymentsGrpcService(currencyManager))
|
||||||
|
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
|
||||||
|
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||||
|
// depends on the user-agent context so it has to come first here!
|
||||||
|
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
||||||
|
serverBuilder
|
||||||
|
// TODO: specialize metrics with user-agent platform
|
||||||
|
.intercept(metricCollectingServerInterceptor)
|
||||||
|
.intercept(errorMappingInterceptor)
|
||||||
|
.intercept(acceptLanguageInterceptor)
|
||||||
|
.intercept(remoteDeprecationFilter)
|
||||||
|
.intercept(userAgentInterceptor)
|
||||||
|
.intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager)))
|
||||||
|
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
|
||||||
|
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
|
||||||
|
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))
|
||||||
|
.addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
|
||||||
|
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
environment.lifecycle().manage(localEventLoopGroup);
|
||||||
|
environment.lifecycle().manage(anonymousGrpcServer);
|
||||||
|
environment.lifecycle().manage(authenticatedGrpcServer);
|
||||||
|
|
||||||
final List<Filter> filters = new ArrayList<>();
|
final List<Filter> filters = new ArrayList<>();
|
||||||
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
|
|
||||||
filters.add(remoteDeprecationFilter);
|
filters.add(remoteDeprecationFilter);
|
||||||
filters.add(new RemoteAddressFilter(useRemoteAddress));
|
filters.add(new RemoteAddressFilter(useRemoteAddress));
|
||||||
|
|
||||||
|
@ -776,19 +824,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
|
||||||
// depends on the user-agent context so it has to come first here!
|
|
||||||
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
|
||||||
grpcServer
|
|
||||||
// TODO: specialize metrics with user-agent platform
|
|
||||||
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
|
|
||||||
.intercept(new ErrorMappingInterceptor())
|
|
||||||
.intercept(new AcceptLanguageInterceptor())
|
|
||||||
.intercept(remoteDeprecationFilter)
|
|
||||||
.intercept(new UserAgentInterceptor());
|
|
||||||
|
|
||||||
environment.lifecycle().manage(new GrpcServerManagedWrapper(grpcServer.build()));
|
|
||||||
|
|
||||||
final AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter =
|
final AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter =
|
||||||
new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
|
new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
|
||||||
.setAuthenticator(accountAuthenticator)
|
.setAuthenticator(accountAuthenticator)
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2023 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.grpc;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
import io.dropwizard.lifecycle.Managed;
|
|
||||||
import io.grpc.Server;
|
|
||||||
|
|
||||||
public class GrpcServerManagedWrapper implements Managed {
|
|
||||||
|
|
||||||
private final Server server;
|
|
||||||
|
|
||||||
public GrpcServerManagedWrapper(final Server server) {
|
|
||||||
this.server = server;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void start() throws IOException {
|
|
||||||
server.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void stop() {
|
|
||||||
try {
|
|
||||||
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
server.shutdownNow();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,124 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.util.internal.EmptyArrays;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An abstract base class for XX- and NX-patterned Noise responder handshake handlers.
|
||||||
|
*
|
||||||
|
* @see <a href="https://noiseprotocol.org/noise.html">The Noise Protocol Framework</a>
|
||||||
|
*/
|
||||||
|
abstract class AbstractNoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final ECKeyPair ecKeyPair;
|
||||||
|
private final byte[] publicKeySignature;
|
||||||
|
|
||||||
|
private final HandshakeState handshakeState;
|
||||||
|
|
||||||
|
private static final int EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH = 32;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a new Noise handler with the given static server keys and static public key signature. The static public
|
||||||
|
* key must be signed by a trusted root private key whose public key is known to and trusted by authenticating
|
||||||
|
* clients.
|
||||||
|
*
|
||||||
|
* @param noiseProtocolName the name of the Noise protocol implemented by this handshake handler
|
||||||
|
* @param ecKeyPair the static key pair for this server
|
||||||
|
* @param publicKeySignature an Ed25519 signature of the raw bytes of the static public key
|
||||||
|
*/
|
||||||
|
AbstractNoiseHandshakeHandler(final String noiseProtocolName,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final byte[] publicKeySignature) {
|
||||||
|
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
this.publicKeySignature = publicKeySignature;
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.handshakeState = new HandshakeState(noiseProtocolName, HandshakeState.RESPONDER);
|
||||||
|
} catch (final NoSuchAlgorithmException e) {
|
||||||
|
throw new AssertionError("Unsupported Noise algorithm: " + noiseProtocolName, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected HandshakeState getHandshakeState() {
|
||||||
|
return handshakeState;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles an initial ephemeral key message from a client, advancing the handshake state and sending the server's
|
||||||
|
* static keys to the client. Both XX and NX patterns begin with a client sending its ephemeral key to the server.
|
||||||
|
* Clients must not include an additional payload with their ephemeral key message. The server's reply contains its
|
||||||
|
* static keys along with an Ed25519 signature of its public static key by a trusted root key.
|
||||||
|
*
|
||||||
|
* @param context the channel handler context for this message
|
||||||
|
* @param frame the websocket frame containing the ephemeral key message
|
||||||
|
*
|
||||||
|
* @throws NoiseHandshakeException if the ephemeral key message from the client was not of the expected size or if a
|
||||||
|
* general Noise encryption error occurred
|
||||||
|
*/
|
||||||
|
protected void handleEphemeralKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||||
|
throws NoiseHandshakeException {
|
||||||
|
|
||||||
|
if (frame.content().readableBytes() != EXPECTED_EPHEMERAL_KEY_MESSAGE_LENGTH) {
|
||||||
|
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
|
||||||
|
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
|
||||||
|
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
|
||||||
|
// key), so we just set the private key here.
|
||||||
|
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
|
||||||
|
handshakeState.start();
|
||||||
|
|
||||||
|
// The initial message from the client should just include a plaintext ephemeral key with no payload. The frame is
|
||||||
|
// coming off the wire and so will be in a direct buffer that doesn't have a backing array.
|
||||||
|
final byte[] ephemeralKeyMessage = ByteBufUtil.getBytes(frame.content());
|
||||||
|
frame.content().readBytes(ephemeralKeyMessage);
|
||||||
|
|
||||||
|
try {
|
||||||
|
handshakeState.readMessage(ephemeralKeyMessage, 0, ephemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||||
|
} catch (final ShortBufferException e) {
|
||||||
|
// This should never happen since we're checking the length of the frame up front
|
||||||
|
throw new NoiseHandshakeException("Unexpected client payload");
|
||||||
|
} catch (final BadPaddingException e) {
|
||||||
|
// It turns out this should basically never happen because (a) we're not using padding and (b) the "bad AEAD tag"
|
||||||
|
// subclass of a bad padding exception can only happen if we have some AD to check, which we don't for an
|
||||||
|
// ephemeral-key-only message
|
||||||
|
throw new NoiseHandshakeException("Invalid keys");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send our key material and public key signature back to the client; this buffer will include:
|
||||||
|
//
|
||||||
|
// - A 32-byte plaintext ephemeral key
|
||||||
|
// - A 32-byte encrypted static key
|
||||||
|
// - A 16-byte AEAD tag for the static key
|
||||||
|
// - The public key signature payload
|
||||||
|
// - A 16-byte AEAD tag for the payload
|
||||||
|
final byte[] keyMaterial = new byte[32 + 32 + 16 + publicKeySignature.length + 16];
|
||||||
|
|
||||||
|
try {
|
||||||
|
handshakeState.writeMessage(keyMaterial, 0, publicKeySignature, 0, publicKeySignature.length);
|
||||||
|
|
||||||
|
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(keyMaterial)))
|
||||||
|
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||||
|
} catch (final ShortBufferException e) {
|
||||||
|
// This should never happen for messages of known length that we control
|
||||||
|
throw new AssertionError("Key material buffer was too short for message", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
handshakeState.destroy();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
|
||||||
|
enum ApplicationWebSocketCloseReason {
|
||||||
|
NOISE_HANDSHAKE_ERROR(4001),
|
||||||
|
CLIENT_AUTHENTICATION_ERROR(4002),
|
||||||
|
NOISE_ENCRYPTION_ERROR(4003),
|
||||||
|
REAUTHENTICATION_REQUIRED(4004);
|
||||||
|
|
||||||
|
private final int statusCode;
|
||||||
|
|
||||||
|
ApplicationWebSocketCloseReason(final int statusCode) {
|
||||||
|
this.statusCode = statusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getStatusCode() {
|
||||||
|
return statusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
|
||||||
|
return new WebSocketCloseStatus(statusCode, reason);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates that an attempt to authenticate a remote client failed for some reason.
|
||||||
|
*/
|
||||||
|
class ClientAuthenticationException extends Exception {
|
||||||
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
|
||||||
|
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
|
||||||
|
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
|
||||||
|
* connection.
|
||||||
|
*/
|
||||||
|
class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private boolean websocketHandshakeComplete = false;
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||||
|
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||||
|
setWebsocketHandshakeComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
context.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setWebsocketHandshakeComplete() {
|
||||||
|
this.websocketHandshakeComplete = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||||
|
if (websocketHandshakeComplete) {
|
||||||
|
final WebSocketCloseStatus webSocketCloseStatus = switch (cause) {
|
||||||
|
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
|
||||||
|
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
|
||||||
|
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
|
||||||
|
default -> {
|
||||||
|
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
||||||
|
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
} else {
|
||||||
|
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||||
|
// way; just close the connection instead.
|
||||||
|
context.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.bootstrap.Bootstrap;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
|
||||||
|
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
||||||
|
* server.
|
||||||
|
*/
|
||||||
|
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final LocalAddress authenticatedGrpcServerAddress;
|
||||||
|
private final LocalAddress anonymousGrpcServerAddress;
|
||||||
|
private final List<Object> pendingReads = new ArrayList<>();
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
|
||||||
|
|
||||||
|
public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress) {
|
||||||
|
|
||||||
|
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
|
||||||
|
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||||
|
pendingReads.add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception {
|
||||||
|
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
|
||||||
|
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||||
|
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||||
|
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||||
|
// authenticated service.
|
||||||
|
final LocalAddress grpcServerAddress = noiseHandshakeCompleteEvent.authenticatedDevice().isPresent()
|
||||||
|
? authenticatedGrpcServerAddress
|
||||||
|
: anonymousGrpcServerAddress;
|
||||||
|
|
||||||
|
new Bootstrap()
|
||||||
|
.remoteAddress(grpcServerAddress)
|
||||||
|
// TODO Set local address
|
||||||
|
.channel(LocalChannel.class)
|
||||||
|
.group(remoteChannelContext.channel().eventLoop())
|
||||||
|
.handler(new ChannelInitializer<LocalChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(final LocalChannel localChannel) {
|
||||||
|
localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel()));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.connect()
|
||||||
|
.addListener((ChannelFutureListener) future -> {
|
||||||
|
if (future.isSuccess()) {
|
||||||
|
// Close the local connection if the remote channel closes and vice versa
|
||||||
|
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
||||||
|
future.channel().closeFuture().addListener(closeFuture ->
|
||||||
|
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
|
||||||
|
|
||||||
|
remoteChannelContext.pipeline()
|
||||||
|
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel()));
|
||||||
|
|
||||||
|
// Flush any buffered reads we accumulated while waiting to open the connection
|
||||||
|
pendingReads.forEach(remoteChannelContext::fireChannelRead);
|
||||||
|
pendingReads.clear();
|
||||||
|
|
||||||
|
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
|
||||||
|
} else {
|
||||||
|
log.warn("Failed to establish local connection to gRPC server", future.cause());
|
||||||
|
remoteChannelContext.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteChannelContext.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
pendingReads.forEach(ReferenceCountUtil::release);
|
||||||
|
pendingReads.clear();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
|
||||||
|
* Dropwizard to manage the lifecycle of the event loop group.
|
||||||
|
*/
|
||||||
|
public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() throws InterruptedException {
|
||||||
|
this.shutdownGracefully().await();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.grpc.Server;
|
||||||
|
import io.grpc.ServerBuilder;
|
||||||
|
import io.grpc.netty.NettyServerBuilder;
|
||||||
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalServerChannel;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress}
|
||||||
|
* and whose lifecycle is managed by Dropwizard via the {@link Managed} interface.
|
||||||
|
*/
|
||||||
|
public abstract class ManagedLocalGrpcServer implements Managed {
|
||||||
|
|
||||||
|
private final Server server;
|
||||||
|
|
||||||
|
public ManagedLocalGrpcServer(final LocalAddress localAddress,
|
||||||
|
final DefaultEventLoopGroup eventLoopGroup) {
|
||||||
|
|
||||||
|
final ServerBuilder<?> serverBuilder = NettyServerBuilder.forAddress(localAddress)
|
||||||
|
.channelType(LocalServerChannel.class)
|
||||||
|
.bossEventLoopGroup(eventLoopGroup)
|
||||||
|
.workerEventLoopGroup(eventLoopGroup);
|
||||||
|
|
||||||
|
configureServer(serverBuilder);
|
||||||
|
|
||||||
|
server = serverBuilder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract void configureServer(final ServerBuilder<?> serverBuilder);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() throws IOException {
|
||||||
|
server.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() {
|
||||||
|
try {
|
||||||
|
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
|
||||||
|
} catch (final InterruptedException e) {
|
||||||
|
server.shutdownNow();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A wrapper for a Netty {@link NioEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
|
||||||
|
* Dropwizard to manage the lifecycle of the event loop group.
|
||||||
|
*/
|
||||||
|
public class ManagedNioEventLoopGroup extends NioEventLoopGroup implements Managed {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() throws Exception {
|
||||||
|
this.shutdownGracefully().await();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An event that indicates that a Noise handshake has completed, possibly authenticating a caller in the process.
|
||||||
|
*
|
||||||
|
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||||
|
* type that performs authentication
|
||||||
|
*/
|
||||||
|
record NoiseHandshakeCompleteEvent(Optional<AuthenticatedDevice> authenticatedDevice) {
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
||||||
|
* a general encryption error).
|
||||||
|
*/
|
||||||
|
class NoiseHandshakeException extends Exception {
|
||||||
|
|
||||||
|
public NoiseHandshakeException(final String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import java.util.Optional;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Noise NX handler handles the responder side of a Noise NX handshake.
|
||||||
|
*/
|
||||||
|
class NoiseNXHandshakeHandler extends AbstractNoiseHandshakeHandler {
|
||||||
|
|
||||||
|
static final String NOISE_PROTOCOL_NAME = "Noise_NX_25519_ChaChaPoly_BLAKE2b";
|
||||||
|
|
||||||
|
NoiseNXHandshakeHandler(final ECKeyPair ecKeyPair, final byte[] publicKeySignature) {
|
||||||
|
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof BinaryWebSocketFrame frame) {
|
||||||
|
try {
|
||||||
|
handleEphemeralKeyMessage(context, frame);
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
// All we need to do is accept the client's ephemeral key and send our own static keys; after that, we can consider
|
||||||
|
// the handshake complete
|
||||||
|
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||||
|
context.pipeline().replace(NoiseNXHandshakeHandler.this, null, new NoiseStreamHandler(getHandshakeState().split()));
|
||||||
|
} else {
|
||||||
|
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||||
|
// error
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.CipherState;
|
||||||
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelDuplexHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Noise stream handler manages a bidirectional Noise session after a handshake has completed.
|
||||||
|
*/
|
||||||
|
class NoiseStreamHandler extends ChannelDuplexHandler {
|
||||||
|
|
||||||
|
private final CipherStatePair cipherStatePair;
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(NoiseStreamHandler.class);
|
||||||
|
|
||||||
|
NoiseStreamHandler(CipherStatePair cipherStatePair) {
|
||||||
|
this.cipherStatePair = cipherStatePair;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||||
|
throws ShortBufferException, BadPaddingException {
|
||||||
|
|
||||||
|
if (message instanceof BinaryWebSocketFrame frame) {
|
||||||
|
try {
|
||||||
|
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||||
|
|
||||||
|
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||||
|
// We'll need to copy it to a heap buffer.
|
||||||
|
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content());
|
||||||
|
|
||||||
|
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||||
|
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
|
||||||
|
|
||||||
|
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||||
|
// error
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
||||||
|
if (message instanceof ByteBuf plaintext) {
|
||||||
|
try {
|
||||||
|
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
||||||
|
final CipherState cipherState = cipherStatePair.getSender();
|
||||||
|
final int plaintextLength = plaintext.readableBytes();
|
||||||
|
|
||||||
|
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
||||||
|
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
||||||
|
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
||||||
|
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
||||||
|
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
||||||
|
|
||||||
|
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||||
|
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||||
|
|
||||||
|
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
|
||||||
|
} finally {
|
||||||
|
plaintext.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!(message instanceof WebSocketFrame)) {
|
||||||
|
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
|
||||||
|
// get issued in response to exceptions)
|
||||||
|
log.warn("Unexpected object in pipeline: {}", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
context.write(message, promise);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
cipherStatePair.destroy();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Noise XX handler handles the responder side of a Noise XX handshake. This implementation expects clients to send
|
||||||
|
* identifying information (an account identifier and device ID) as an additional payload when sending its static key
|
||||||
|
* material. It compares the static public key against the stored public key for the identified device asynchronously,
|
||||||
|
* buffering traffic from the client until the authentication check completes.
|
||||||
|
*/
|
||||||
|
class NoiseXXHandshakeHandler extends AbstractNoiseHandshakeHandler {
|
||||||
|
|
||||||
|
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
|
private AuthenticationState authenticationState = AuthenticationState.GET_EPHEMERAL_KEY;
|
||||||
|
|
||||||
|
private final List<BinaryWebSocketFrame> pendingInboundFrames = new ArrayList<>();
|
||||||
|
|
||||||
|
static final String NOISE_PROTOCOL_NAME = "Noise_XX_25519_ChaChaPoly_BLAKE2b";
|
||||||
|
|
||||||
|
// When the client sends its static key message, we expect:
|
||||||
|
//
|
||||||
|
// - A 32-byte encrypted static public key
|
||||||
|
// - A 16-byte AEAD tag for the static key
|
||||||
|
// - 17 bytes of identity data in the message payload (a UUID and a one-byte device ID)
|
||||||
|
// - A 16-byte AEAD tag for the identity payload
|
||||||
|
private static final int EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH = 81;
|
||||||
|
|
||||||
|
private enum AuthenticationState {
|
||||||
|
GET_EPHEMERAL_KEY,
|
||||||
|
GET_STATIC_KEY,
|
||||||
|
CHECK_PUBLIC_KEY,
|
||||||
|
ERROR
|
||||||
|
}
|
||||||
|
|
||||||
|
public NoiseXXHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final byte[] publicKeySignature) {
|
||||||
|
|
||||||
|
super(NOISE_PROTOCOL_NAME, ecKeyPair, publicKeySignature);
|
||||||
|
|
||||||
|
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof BinaryWebSocketFrame frame) {
|
||||||
|
try {
|
||||||
|
switch (authenticationState) {
|
||||||
|
case GET_EPHEMERAL_KEY -> {
|
||||||
|
try {
|
||||||
|
handleEphemeralKeyMessage(context, frame);
|
||||||
|
authenticationState = AuthenticationState.GET_STATIC_KEY;
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case GET_STATIC_KEY -> {
|
||||||
|
try {
|
||||||
|
handleStaticKey(context, frame);
|
||||||
|
authenticationState = AuthenticationState.CHECK_PUBLIC_KEY;
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case CHECK_PUBLIC_KEY -> {
|
||||||
|
// Buffer any inbound traffic until we've finished checking the client's public key
|
||||||
|
pendingInboundFrames.add(frame);
|
||||||
|
}
|
||||||
|
case ERROR -> {
|
||||||
|
// If authentication has failed for any reason, just discard inbound traffic until the channel closes
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (final ShortBufferException e) {
|
||||||
|
authenticationState = AuthenticationState.ERROR;
|
||||||
|
throw new NoiseHandshakeException("Unexpected payload length");
|
||||||
|
} catch (final BadPaddingException e) {
|
||||||
|
authenticationState = AuthenticationState.ERROR;
|
||||||
|
throw new ClientAuthenticationException();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
|
||||||
|
// error
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleStaticKey(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||||
|
throws NoiseHandshakeException, ShortBufferException, BadPaddingException {
|
||||||
|
|
||||||
|
if (frame.content().readableBytes() != EXPECTED_CLIENT_STATIC_KEY_MESSAGE_LENGTH) {
|
||||||
|
throw new NoiseHandshakeException("Unexpected client static key message length");
|
||||||
|
}
|
||||||
|
|
||||||
|
final HandshakeState handshakeState = getHandshakeState();
|
||||||
|
|
||||||
|
// The websocket frame will have come right off the wire, and so needs to be copied from a non-array-backed direct
|
||||||
|
// buffer into a heap buffer.
|
||||||
|
final byte[] staticKeyAndClientIdentityMessage = ByteBufUtil.getBytes(frame.content());
|
||||||
|
|
||||||
|
// The payload from the client should be a UUID (16 bytes) followed by a device ID (1 byte)
|
||||||
|
final byte[] payload = new byte[17];
|
||||||
|
|
||||||
|
final UUID accountIdentifier;
|
||||||
|
final byte deviceId;
|
||||||
|
|
||||||
|
final int payloadBytesRead = handshakeState.readMessage(staticKeyAndClientIdentityMessage,
|
||||||
|
0, staticKeyAndClientIdentityMessage.length, payload, 0);
|
||||||
|
|
||||||
|
if (payloadBytesRead != 17) {
|
||||||
|
throw new NoiseHandshakeException("Unexpected identity payload length");
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
accountIdentifier = UUIDUtil.fromBytes(payload, 0);
|
||||||
|
} catch (final IllegalArgumentException e) {
|
||||||
|
throw new NoiseHandshakeException("Could not parse account identifier");
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceId = payload[16];
|
||||||
|
|
||||||
|
// Verify the identity of the caller by comparing the submitted static public key against the stored public key for
|
||||||
|
// the identified device
|
||||||
|
clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)
|
||||||
|
.whenCompleteAsync((maybePublicKey, throwable) -> maybePublicKey.ifPresentOrElse(storedPublicKey -> {
|
||||||
|
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
|
||||||
|
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
|
||||||
|
|
||||||
|
if (MessageDigest.isEqual(publicKeyFromClient, storedPublicKey.getPublicKeyBytes())) {
|
||||||
|
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(
|
||||||
|
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))));
|
||||||
|
|
||||||
|
context.pipeline().addAfter(context.name(), null, new NoiseStreamHandler(handshakeState.split()));
|
||||||
|
|
||||||
|
// Flush any buffered reads
|
||||||
|
pendingInboundFrames.forEach(context::fireChannelRead);
|
||||||
|
pendingInboundFrames.clear();
|
||||||
|
|
||||||
|
context.pipeline().remove(NoiseXXHandshakeHandler.this);
|
||||||
|
} else {
|
||||||
|
// We found a key, but it doesn't match what the caller submitted
|
||||||
|
context.fireExceptionCaught(new ClientAuthenticationException());
|
||||||
|
authenticationState = AuthenticationState.ERROR;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
() -> {
|
||||||
|
// We couldn't find a key for the identified account/device
|
||||||
|
context.fireExceptionCaught(new ClientAuthenticationException());
|
||||||
|
authenticationState = AuthenticationState.ERROR;
|
||||||
|
}),
|
||||||
|
context.executor());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
super.handlerRemoved(context);
|
||||||
|
|
||||||
|
pendingInboundFrames.forEach(BinaryWebSocketFrame::release);
|
||||||
|
pendingInboundFrames.clear();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.Channel;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A proxy handler writes all data read from one channel to another peer channel.
|
||||||
|
*/
|
||||||
|
class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final Channel peerChannel;
|
||||||
|
|
||||||
|
public ProxyHandler(final Channel peerChannel) {
|
||||||
|
this.peerChannel = peerChannel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||||
|
peerChannel.writeAndFlush(message)
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process.
|
||||||
|
*/
|
||||||
|
public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof final WebSocketFrame webSocketFrame) {
|
||||||
|
if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) {
|
||||||
|
try {
|
||||||
|
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
|
||||||
|
} finally {
|
||||||
|
textWebSocketFrame.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Allow all other types of WebSocket frames
|
||||||
|
context.fireChannelRead(webSocketFrame);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Discard anything that's not a WebSocket frame
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||||
|
import io.netty.handler.codec.http.FullHttpRequest;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaderValues;
|
||||||
|
import io.netty.handler.codec.http.HttpMethod;
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully
|
||||||
|
* rejects requests for anything other than a WebSocket connection to a known endpoint.
|
||||||
|
*/
|
||||||
|
class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final String authenticatedPath;
|
||||||
|
private final String anonymousPath;
|
||||||
|
|
||||||
|
WebSocketOpeningHandshakeHandler(final String authenticatedPath, final String anonymousPath) {
|
||||||
|
this.authenticatedPath = authenticatedPath;
|
||||||
|
this.anonymousPath = anonymousPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof FullHttpRequest request) {
|
||||||
|
boolean shouldReleaseRequest = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (request.decoderResult().isSuccess()) {
|
||||||
|
if (HttpMethod.GET.equals(request.method())) {
|
||||||
|
if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) {
|
||||||
|
if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
|
||||||
|
// Pass the request along to the websocket handshake handler and remove ourselves from the pipeline
|
||||||
|
shouldReleaseRequest = false;
|
||||||
|
|
||||||
|
context.fireChannelRead(request);
|
||||||
|
context.pipeline().remove(this);
|
||||||
|
} else {
|
||||||
|
closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (shouldReleaseRequest) {
|
||||||
|
request.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error
|
||||||
|
ReferenceCountUtil.release(message);
|
||||||
|
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void closeConnectionWithStatus(final ChannelHandlerContext context,
|
||||||
|
final FullHttpRequest request,
|
||||||
|
final HttpResponseStatus status) {
|
||||||
|
|
||||||
|
context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandler;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||||
|
* Noise handshake handler for the requested path.
|
||||||
|
*/
|
||||||
|
class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
|
private final ECKeyPair ecKeyPair;
|
||||||
|
private final byte[] publicKeySignature;
|
||||||
|
|
||||||
|
WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final byte[] publicKeySignature) {
|
||||||
|
|
||||||
|
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
this.publicKeySignature = publicKeySignature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||||
|
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||||
|
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
|
||||||
|
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
|
||||||
|
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
|
||||||
|
|
||||||
|
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
|
||||||
|
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
|
||||||
|
|
||||||
|
default -> {
|
||||||
|
// The HttpHandler should have caught all of these cases already; we'll consider it an internal error if
|
||||||
|
// something slipped through.
|
||||||
|
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
context.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.dropwizard.lifecycle.Managed;
|
||||||
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import io.netty.channel.socket.ServerSocketChannel;
|
||||||
|
import io.netty.channel.socket.SocketChannel;
|
||||||
|
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||||
|
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||||
|
import io.netty.handler.codec.http.HttpServerCodec;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import io.netty.handler.ssl.ClientAuth;
|
||||||
|
import io.netty.handler.ssl.OpenSsl;
|
||||||
|
import io.netty.handler.ssl.SslContext;
|
||||||
|
import io.netty.handler.ssl.SslContextBuilder;
|
||||||
|
import io.netty.handler.ssl.SslProtocols;
|
||||||
|
import io.netty.handler.ssl.SslProvider;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.security.PrivateKey;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.util.concurrent.Executor;
|
||||||
|
import javax.net.ssl.SSLException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A WebSocket/Noise tunnel server accepts traffic from the public internet (in the form of Noise packets framed by
|
||||||
|
* binary WebSocket frames) and passes it through to a local gRPC server.
|
||||||
|
*/
|
||||||
|
public class WebsocketNoiseTunnelServer implements Managed {
|
||||||
|
|
||||||
|
private final ServerBootstrap bootstrap;
|
||||||
|
private ServerSocketChannel channel;
|
||||||
|
|
||||||
|
static final String AUTHENTICATED_SERVICE_PATH = "/authenticated";
|
||||||
|
static final String ANONYMOUS_SERVICE_PATH = "/anonymous";
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(WebsocketNoiseTunnelServer.class);
|
||||||
|
|
||||||
|
public WebsocketNoiseTunnelServer(final int websocketPort,
|
||||||
|
final X509Certificate[] tlsCertificateChain,
|
||||||
|
final PrivateKey tlsPrivateKey,
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final Executor delegatedTaskExecutor,
|
||||||
|
final ClientPublicKeysManager clientPublicKeysManager,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final byte[] publicKeySignature,
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress,
|
||||||
|
final LocalAddress anonymousGrpcServerAddress) throws SSLException {
|
||||||
|
|
||||||
|
final SslProvider sslProvider;
|
||||||
|
|
||||||
|
if (OpenSsl.isAvailable()) {
|
||||||
|
log.info("Native OpenSSL provider is available; will use native provider");
|
||||||
|
sslProvider = SslProvider.OPENSSL;
|
||||||
|
} else {
|
||||||
|
log.info("No native SSL provider available; will use JDK provider");
|
||||||
|
sslProvider = SslProvider.JDK;
|
||||||
|
}
|
||||||
|
|
||||||
|
final SslContext sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
|
||||||
|
.clientAuth(ClientAuth.NONE)
|
||||||
|
.protocols(SslProtocols.TLS_v1_3)
|
||||||
|
.sslProvider(sslProvider)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
this.bootstrap = new ServerBootstrap()
|
||||||
|
.group(eventLoopGroup)
|
||||||
|
.channel(NioServerSocketChannel.class)
|
||||||
|
.localAddress(websocketPort)
|
||||||
|
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(SocketChannel socketChannel) {
|
||||||
|
socketChannel.pipeline()
|
||||||
|
.addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor))
|
||||||
|
.addLast(new HttpServerCodec())
|
||||||
|
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
|
||||||
|
// The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade
|
||||||
|
// request and passed it down the pipeline
|
||||||
|
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH))
|
||||||
|
.addLast(new WebSocketServerProtocolHandler("/", true))
|
||||||
|
.addLast(new RejectUnsupportedMessagesHandler())
|
||||||
|
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
|
||||||
|
// a WebSocket handshake has been completed
|
||||||
|
.addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature))
|
||||||
|
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||||
|
// once the Noise handshake has completed
|
||||||
|
.addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
|
||||||
|
.addLast(new ErrorHandler());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
InetSocketAddress getLocalAddress() {
|
||||||
|
return channel.localAddress();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void start() throws InterruptedException {
|
||||||
|
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop() throws InterruptedException {
|
||||||
|
if (channel != null) {
|
||||||
|
channel.close().await();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.util;
|
||||||
import com.google.protobuf.ByteString;
|
import com.google.protobuf.ByteString;
|
||||||
import java.nio.BufferUnderflowException;
|
import java.nio.BufferUnderflowException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
public final class UUIDUtil {
|
public final class UUIDUtil {
|
||||||
|
@ -40,6 +39,10 @@ public final class UUIDUtil {
|
||||||
return fromByteBuffer(ByteBuffer.wrap(bytes));
|
return fromByteBuffer(ByteBuffer.wrap(bytes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static UUID fromBytes(final byte[] bytes, final int offset) {
|
||||||
|
return fromByteBuffer(ByteBuffer.wrap(bytes, offset, 16));
|
||||||
|
}
|
||||||
|
|
||||||
public static UUID fromByteBuffer(final ByteBuffer byteBuffer) {
|
public static UUID fromByteBuffer(final ByteBuffer byteBuffer) {
|
||||||
try {
|
try {
|
||||||
final long mostSigBits = byteBuffer.getLong();
|
final long mostSigBits = byteBuffer.getLong();
|
||||||
|
@ -52,12 +55,4 @@ public final class UUIDUtil {
|
||||||
throw new IllegalArgumentException("unexpected byte array length; was less than 16");
|
throw new IllegalArgumentException("unexpected byte array length; was less than 16");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Optional<UUID> fromStringSafe(final String uuidString) {
|
|
||||||
try {
|
|
||||||
return Optional.of(UUID.fromString(uuidString));
|
|
||||||
} catch (final IllegalArgumentException e) {
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.util.ResourceLeakDetector;
|
||||||
|
import org.junit.jupiter.api.AfterAll;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
|
||||||
|
abstract class AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void setLeakDetectionLevel() {
|
||||||
|
originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
|
||||||
|
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterAll
|
||||||
|
static void restoreLeakDetectionLevel() {
|
||||||
|
ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
|
||||||
|
abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final ECPublicKey rootPublicKey;
|
||||||
|
|
||||||
|
private final HandshakeState handshakeState;
|
||||||
|
|
||||||
|
AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) {
|
||||||
|
this.rootPublicKey = rootPublicKey;
|
||||||
|
|
||||||
|
try {
|
||||||
|
handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR);
|
||||||
|
} catch (final NoSuchAlgorithmException e) {
|
||||||
|
throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract String getNoiseProtocolName();
|
||||||
|
|
||||||
|
protected abstract void startHandshake();
|
||||||
|
|
||||||
|
protected HandshakeState getHandshakeState() {
|
||||||
|
return handshakeState;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||||
|
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||||
|
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||||
|
startHandshake();
|
||||||
|
|
||||||
|
final byte[] ephemeralKeyMessage = new byte[32];
|
||||||
|
handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0);
|
||||||
|
|
||||||
|
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage)))
|
||||||
|
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
super.userEventTriggered(context, event);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
|
||||||
|
throws NoiseHandshakeException {
|
||||||
|
|
||||||
|
// The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap
|
||||||
|
// buffer so we can Noise at it.
|
||||||
|
final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes());
|
||||||
|
final byte[] serverPublicKeySignature = new byte[64];
|
||||||
|
|
||||||
|
try {
|
||||||
|
frame.content().readBytes(keyMaterialBuffer);
|
||||||
|
|
||||||
|
final int payloadBytesRead =
|
||||||
|
handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0);
|
||||||
|
|
||||||
|
if (payloadBytesRead != 64) {
|
||||||
|
throw new NoiseHandshakeException("Unexpected signature length");
|
||||||
|
}
|
||||||
|
} catch (final ShortBufferException e) {
|
||||||
|
throw new NoiseHandshakeException("Unexpected signature length");
|
||||||
|
} catch (final BadPaddingException e) {
|
||||||
|
throw new NoiseHandshakeException("Invalid keys");
|
||||||
|
} finally {
|
||||||
|
keyMaterialBuffer.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
final byte[] serverPublicKey = new byte[32];
|
||||||
|
handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||||
|
|
||||||
|
if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) {
|
||||||
|
throw new NoiseHandshakeException("Invalid server public key signature");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) throws Exception {
|
||||||
|
handshakeState.destroy();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,141 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFuture;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
|
||||||
|
abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private ECPublicKey rootPublicKey;
|
||||||
|
|
||||||
|
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||||
|
|
||||||
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||||
|
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
|
||||||
|
handshakeCompleteEvent = noiseHandshakeCompleteEvent;
|
||||||
|
} else {
|
||||||
|
context.fireUserEventTriggered(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() {
|
||||||
|
return handshakeCompleteEvent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
final ECKeyPair rootKeyPair = Curve.generateKeyPair();
|
||||||
|
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
rootPublicKey = rootKeyPair.getPublicKey();
|
||||||
|
|
||||||
|
final byte[] serverPublicKeySignature =
|
||||||
|
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||||
|
|
||||||
|
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||||
|
|
||||||
|
embeddedChannel =
|
||||||
|
new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() {
|
||||||
|
embeddedChannel.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected EmbeddedChannel getEmbeddedChannel() {
|
||||||
|
return embeddedChannel;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected ECPublicKey getRootPublicKey() {
|
||||||
|
return rootPublicKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() {
|
||||||
|
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleInvalidInitialMessage() throws InterruptedException {
|
||||||
|
final byte[] contentBytes = new byte[17];
|
||||||
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
|
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||||
|
|
||||||
|
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
|
||||||
|
|
||||||
|
assertFalse(writeFuture.isSuccess());
|
||||||
|
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||||
|
assertEquals(0, content.refCnt());
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
||||||
|
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
|
||||||
|
|
||||||
|
for (int i = 0; i < frames.length; i++) {
|
||||||
|
final byte[] contentBytes = new byte[17];
|
||||||
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
|
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
||||||
|
|
||||||
|
embeddedChannel.writeOneInbound(frames[i]).await();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (final BinaryWebSocketFrame frame : frames) {
|
||||||
|
assertEquals(0, frame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleNonWebSocketBinaryFrame() throws InterruptedException {
|
||||||
|
final byte[] contentBytes = new byte[17];
|
||||||
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
|
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
|
||||||
|
|
||||||
|
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
||||||
|
|
||||||
|
assertFalse(writeFuture.isSuccess());
|
||||||
|
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
|
||||||
|
assertEquals(0, message.refCnt());
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.grpc.stub.StreamObserver;
|
||||||
|
import org.signal.chat.rpc.AuthenticationTypeGrpc;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedRequest;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedResponse;
|
||||||
|
|
||||||
|
public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
|
||||||
|
|
||||||
|
private final boolean authenticated;
|
||||||
|
|
||||||
|
public AuthenticationTypeService(final boolean authenticated) {
|
||||||
|
this.authenticated = authenticated;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver<GetAuthenticatedResponse> responseObserver) {
|
||||||
|
responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
|
||||||
|
responseObserver.onCompleted();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||||
|
|
||||||
|
class ClientErrorHandler extends ErrorHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||||
|
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||||
|
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||||
|
setWebsocketHandshakeComplete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
super.userEventTriggered(context, event);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,141 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
|
import io.netty.bootstrap.Bootstrap;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.socket.SocketChannel;
|
||||||
|
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||||
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.HttpClientCodec;
|
||||||
|
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||||
|
import io.netty.handler.ssl.SslContextBuilder;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import javax.net.ssl.SSLException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
|
||||||
|
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final X509Certificate trustedServerCertificate;
|
||||||
|
private final URI websocketUri;
|
||||||
|
private final boolean authenticated;
|
||||||
|
@Nullable private final ECKeyPair ecKeyPair;
|
||||||
|
private final ECPublicKey rootPublicKey;
|
||||||
|
@Nullable private final UUID accountIdentifier;
|
||||||
|
private final byte deviceId;
|
||||||
|
private final SocketAddress remoteServerAddress;
|
||||||
|
private final WebSocketCloseListener webSocketCloseListener;
|
||||||
|
|
||||||
|
private final List<Object> pendingReads = new ArrayList<>();
|
||||||
|
|
||||||
|
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
|
||||||
|
|
||||||
|
EstablishRemoteConnectionHandler(
|
||||||
|
final X509Certificate trustedServerCertificate,
|
||||||
|
final URI websocketUri,
|
||||||
|
final boolean authenticated,
|
||||||
|
@Nullable final ECKeyPair ecKeyPair,
|
||||||
|
final ECPublicKey rootPublicKey,
|
||||||
|
@Nullable final UUID accountIdentifier,
|
||||||
|
final byte deviceId,
|
||||||
|
final SocketAddress remoteServerAddress,
|
||||||
|
final WebSocketCloseListener webSocketCloseListener) {
|
||||||
|
|
||||||
|
this.trustedServerCertificate = trustedServerCertificate;
|
||||||
|
this.websocketUri = websocketUri;
|
||||||
|
this.authenticated = authenticated;
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
this.rootPublicKey = rootPublicKey;
|
||||||
|
this.accountIdentifier = accountIdentifier;
|
||||||
|
this.deviceId = deviceId;
|
||||||
|
this.remoteServerAddress = remoteServerAddress;
|
||||||
|
this.webSocketCloseListener = webSocketCloseListener;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerAdded(final ChannelHandlerContext localContext) {
|
||||||
|
new Bootstrap()
|
||||||
|
.channel(NioSocketChannel.class)
|
||||||
|
.group(localContext.channel().eventLoop())
|
||||||
|
.handler(new ChannelInitializer<SocketChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(final SocketChannel channel) throws SSLException {
|
||||||
|
channel.pipeline()
|
||||||
|
.addLast(SslContextBuilder
|
||||||
|
.forClient()
|
||||||
|
.trustManager(trustedServerCertificate)
|
||||||
|
.build()
|
||||||
|
.newHandler(channel.alloc()))
|
||||||
|
.addLast(new HttpClientCodec())
|
||||||
|
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
|
||||||
|
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
|
||||||
|
// want to react to them on our own, we need to catch them before they hit that handler.
|
||||||
|
.addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
||||||
|
.addLast(new WebSocketClientProtocolHandler(websocketUri,
|
||||||
|
WebSocketVersion.V13,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
new DefaultHttpHeaders(),
|
||||||
|
Noise.MAX_PACKET_LEN,
|
||||||
|
10_000))
|
||||||
|
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
|
||||||
|
.addLast(authenticated
|
||||||
|
? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId)
|
||||||
|
: new NoiseNXClientHandshakeHandler(rootPublicKey))
|
||||||
|
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
|
||||||
|
throws Exception {
|
||||||
|
if (event instanceof NoiseHandshakeCompleteEvent) {
|
||||||
|
remoteContext.pipeline()
|
||||||
|
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||||
|
|
||||||
|
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||||
|
|
||||||
|
pendingReads.forEach(localContext::fireChannelRead);
|
||||||
|
pendingReads.clear();
|
||||||
|
|
||||||
|
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
|
||||||
|
}
|
||||||
|
|
||||||
|
super.userEventTriggered(remoteContext, event);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.addLast(new ClientErrorHandler());
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.connect(remoteServerAddress)
|
||||||
|
.addListener((ChannelFutureListener) future -> {
|
||||||
|
if (future.isSuccess()) {
|
||||||
|
// Close the local connection if the remote channel closes and vice versa
|
||||||
|
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
|
||||||
|
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
||||||
|
} else {
|
||||||
|
localContext.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||||
|
pendingReads.add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||||
|
pendingReads.forEach(ReferenceCountUtil::release);
|
||||||
|
pendingReads.clear();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
|
||||||
|
class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final WebSocketCloseListener webSocketCloseListener;
|
||||||
|
|
||||||
|
public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
|
||||||
|
this.webSocketCloseListener = webSocketCloseListener;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||||
|
webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
super.channelRead(context, message);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import java.util.Optional;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
|
||||||
|
class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler {
|
||||||
|
|
||||||
|
private boolean receivedServerStaticKeyMessage = false;
|
||||||
|
|
||||||
|
NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) {
|
||||||
|
super(rootPublicKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String getNoiseProtocolName() {
|
||||||
|
return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void startHandshake() {
|
||||||
|
getHandshakeState().start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||||
|
if (message instanceof BinaryWebSocketFrame frame) {
|
||||||
|
try {
|
||||||
|
// Don't process additional messages if we're just waiting to close because the handshake failed
|
||||||
|
if (receivedServerStaticKeyMessage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedServerStaticKeyMessage = true;
|
||||||
|
handleServerStaticKeyMessage(context, frame);
|
||||||
|
|
||||||
|
context.pipeline().replace(this, null, new NoiseStreamHandler(getHandshakeState().split()));
|
||||||
|
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
context.fireChannelRead(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,84 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.util.Optional;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
|
||||||
|
class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
|
||||||
|
final byte[] serverPublicKeySignature) {
|
||||||
|
|
||||||
|
return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshake()
|
||||||
|
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState =
|
||||||
|
new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||||
|
|
||||||
|
clientHandshakeState.start();
|
||||||
|
|
||||||
|
{
|
||||||
|
final byte[] ephemeralKeyMessageBytes = new byte[32];
|
||||||
|
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
|
||||||
|
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
|
||||||
|
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
||||||
|
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||||
|
|
||||||
|
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||||
|
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
||||||
|
|
||||||
|
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
||||||
|
|
||||||
|
final byte[] serverPublicKeySignature = new byte[64];
|
||||||
|
|
||||||
|
final int payloadLength =
|
||||||
|
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
|
||||||
|
|
||||||
|
assertEquals(serverPublicKeySignature.length, payloadLength);
|
||||||
|
|
||||||
|
final byte[] serverPublicKey = new byte[32];
|
||||||
|
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||||
|
|
||||||
|
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class),
|
||||||
|
"Handshake handler should remove self from pipeline after successful handshake");
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,135 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFuture;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.util.internal.EmptyArrays;
|
||||||
|
import io.netty.util.internal.ThreadLocalRandom;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import javax.crypto.AEADBadTagException;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class NoiseStreamHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private CipherStatePair clientCipherStatePair;
|
||||||
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
// We use an NN handshake for this test just because it's a little shorter and easier to set up
|
||||||
|
private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b";
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException {
|
||||||
|
final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||||
|
final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER);
|
||||||
|
|
||||||
|
clientHandshakeState.start();
|
||||||
|
serverHandshakeState.start();
|
||||||
|
|
||||||
|
final byte[] clientEphemeralKeyMessage = new byte[32];
|
||||||
|
assertEquals(clientEphemeralKeyMessage.length,
|
||||||
|
clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0));
|
||||||
|
|
||||||
|
serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||||
|
|
||||||
|
// 32 bytes of key material plus a 16-byte MAC
|
||||||
|
final byte[] serverEphemeralKeyMessage = new byte[48];
|
||||||
|
assertEquals(serverEphemeralKeyMessage.length,
|
||||||
|
serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0));
|
||||||
|
|
||||||
|
clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
|
||||||
|
|
||||||
|
clientCipherStatePair = clientHandshakeState.split();
|
||||||
|
embeddedChannel = new EmbeddedChannel(new NoiseStreamHandler(serverHandshakeState.split()));
|
||||||
|
|
||||||
|
clientHandshakeState.destroy();
|
||||||
|
serverHandshakeState.destroy();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void channelRead() throws ShortBufferException, InterruptedException {
|
||||||
|
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||||
|
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
||||||
|
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
||||||
|
assertEquals(0, ciphertextFrame.refCnt());
|
||||||
|
|
||||||
|
final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll();
|
||||||
|
assertNotNull(decryptedPlaintextBuffer);
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
|
||||||
|
final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer);
|
||||||
|
decryptedPlaintextBuffer.release();
|
||||||
|
|
||||||
|
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void channelReadBadCiphertext() throws InterruptedException {
|
||||||
|
final byte[] bogusCiphertext = new byte[32];
|
||||||
|
ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
|
||||||
|
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
||||||
|
|
||||||
|
assertEquals(0, ciphertextFrame.refCnt());
|
||||||
|
assertFalse(readCiphertextFuture.isSuccess());
|
||||||
|
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void channelReadUnexpectedMessageType() throws InterruptedException {
|
||||||
|
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
|
||||||
|
|
||||||
|
assertFalse(readFuture.isSuccess());
|
||||||
|
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void write() throws InterruptedException, ShortBufferException, BadPaddingException {
|
||||||
|
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||||
|
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
|
||||||
|
|
||||||
|
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||||
|
assertTrue(writePlaintextFuture.await().isSuccess());
|
||||||
|
assertEquals(0, plaintextBuffer.refCnt());
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||||
|
assertNotNull(ciphertextFrame);
|
||||||
|
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||||
|
|
||||||
|
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
||||||
|
ciphertextFrame.release();
|
||||||
|
|
||||||
|
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
||||||
|
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
|
||||||
|
|
||||||
|
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void writeUnexpectedMessageType() throws InterruptedException {
|
||||||
|
final Object unexpectedMessaged = new Object();
|
||||||
|
|
||||||
|
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
|
||||||
|
assertTrue(writeFuture.await().isSuccess());
|
||||||
|
|
||||||
|
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
|
||||||
|
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,89 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
|
||||||
|
class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler {
|
||||||
|
|
||||||
|
private final ECKeyPair ecKeyPair;
|
||||||
|
|
||||||
|
private final UUID accountIdentifier;
|
||||||
|
private final byte deviceId;
|
||||||
|
|
||||||
|
private boolean receivedServerStaticKeyMessage = false;
|
||||||
|
|
||||||
|
NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair,
|
||||||
|
final ECPublicKey rootPublicKey,
|
||||||
|
final UUID accountIdentifier,
|
||||||
|
final byte deviceId) {
|
||||||
|
|
||||||
|
super(rootPublicKey);
|
||||||
|
|
||||||
|
this.ecKeyPair = ecKeyPair;
|
||||||
|
|
||||||
|
this.accountIdentifier = accountIdentifier;
|
||||||
|
this.deviceId = deviceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String getNoiseProtocolName() {
|
||||||
|
return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void startHandshake() {
|
||||||
|
final HandshakeState handshakeState = getHandshakeState();
|
||||||
|
|
||||||
|
// Noise-java derives the public key from the private key, so we just need to set the private key
|
||||||
|
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
|
||||||
|
handshakeState.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||||
|
throws NoiseHandshakeException, ShortBufferException {
|
||||||
|
if (message instanceof BinaryWebSocketFrame frame) {
|
||||||
|
try {
|
||||||
|
// Don't process additional messages if the handshake failed and we're just waiting to close
|
||||||
|
if (receivedServerStaticKeyMessage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedServerStaticKeyMessage = true;
|
||||||
|
handleServerStaticKeyMessage(context, frame);
|
||||||
|
|
||||||
|
final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17);
|
||||||
|
clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||||
|
clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||||
|
clientIdentityBuffer.put(deviceId);
|
||||||
|
clientIdentityBuffer.flip();
|
||||||
|
|
||||||
|
final HandshakeState handshakeState = getHandshakeState();
|
||||||
|
|
||||||
|
// We're sending two 32-byte keys plus the client identity payload
|
||||||
|
final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()];
|
||||||
|
handshakeState.writeMessage(
|
||||||
|
staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining());
|
||||||
|
|
||||||
|
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage)))
|
||||||
|
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||||
|
|
||||||
|
context.pipeline().replace(this, null, new NoiseStreamHandler(handshakeState.split()));
|
||||||
|
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
context.fireChannelRead(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,454 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.southernstorm.noise.protocol.CipherState;
|
||||||
|
import com.southernstorm.noise.protocol.HandshakeState;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.ChannelFuture;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.ShortBufferException;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import io.netty.util.internal.EmptyArrays;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
|
||||||
|
class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
|
||||||
|
|
||||||
|
private ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||||
|
|
||||||
|
super.setUp();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
|
||||||
|
final byte[] serverPublicKeySignature) {
|
||||||
|
|
||||||
|
return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshake()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||||
|
getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should remove self from pipeline after successful handshake");
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakeMissingIdentityInformation()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
|
||||||
|
{
|
||||||
|
final byte[] clientStaticKeyMessageBytes = new byte[64];
|
||||||
|
final int messageLength =
|
||||||
|
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0);
|
||||||
|
|
||||||
|
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||||
|
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||||
|
|
||||||
|
final ChannelFuture writeClientStaticKeyMessageFuture =
|
||||||
|
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
|
||||||
|
|
||||||
|
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
|
||||||
|
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
|
||||||
|
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakeMalformedIdentityInformation()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
|
||||||
|
{
|
||||||
|
final byte[] clientStaticKeyMessageBytes = new byte[96];
|
||||||
|
final int messageLength =
|
||||||
|
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32);
|
||||||
|
|
||||||
|
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||||
|
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||||
|
|
||||||
|
final ChannelFuture writeClientStaticKeyMessageFuture =
|
||||||
|
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
|
||||||
|
|
||||||
|
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
|
||||||
|
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
|
||||||
|
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakeUnrecognizedDevice()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakePublicKeyMismatch()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakeBufferedReads()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||||
|
|
||||||
|
final ByteBuf[] additionalMessages = new ByteBuf[4];
|
||||||
|
final CipherState senderState = clientHandshakeState.split().getSender();
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (int i = 0; i < additionalMessages.length; i++) {
|
||||||
|
final byte[] contentBytes = new byte[32];
|
||||||
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
|
// Copy the "plaintext" portion of the content bytes for future assertions
|
||||||
|
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
|
||||||
|
|
||||||
|
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
|
||||||
|
// tag
|
||||||
|
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
|
||||||
|
|
||||||
|
assertTrue(
|
||||||
|
embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await()
|
||||||
|
.isSuccess());
|
||||||
|
}
|
||||||
|
|
||||||
|
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
|
||||||
|
getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should remove self from pipeline after successful handshake");
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Handshake handler should insert a Noise stream handler after successful handshake");
|
||||||
|
|
||||||
|
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||||
|
assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(),
|
||||||
|
"Buffered message should pass through pipeline after successful handshake");
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||||
|
additionalMessage.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleCompleteHandshakeFailureBufferedReads()
|
||||||
|
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
|
||||||
|
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
|
||||||
|
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
|
||||||
|
|
||||||
|
final ByteBuf[] additionalMessages = new ByteBuf[4];
|
||||||
|
final CipherState senderState = clientHandshakeState.split().getSender();
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (int i = 0; i < additionalMessages.length; i++) {
|
||||||
|
final byte[] contentBytes = new byte[32];
|
||||||
|
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||||
|
|
||||||
|
// Copy the "plaintext" portion of the content bytes for future assertions
|
||||||
|
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
|
||||||
|
|
||||||
|
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
|
||||||
|
// tag
|
||||||
|
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess());
|
||||||
|
}
|
||||||
|
|
||||||
|
findPublicKeyFuture.complete(Optional.empty());
|
||||||
|
|
||||||
|
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||||
|
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||||
|
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||||
|
// and issue a "handshake complete" event.
|
||||||
|
embeddedChannel.runPendingTasks();
|
||||||
|
|
||||||
|
assertNull(getNoiseHandshakeCompleteEvent());
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
|
||||||
|
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
|
||||||
|
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty(),
|
||||||
|
"Buffered messages should not pass through pipeline after failed handshake");
|
||||||
|
} finally {
|
||||||
|
for (final ByteBuf additionalMessage : additionalMessages) {
|
||||||
|
additionalMessage.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair)
|
||||||
|
throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException {
|
||||||
|
|
||||||
|
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||||
|
|
||||||
|
final HandshakeState clientHandshakeState =
|
||||||
|
new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
|
||||||
|
|
||||||
|
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
|
||||||
|
clientHandshakeState.start();
|
||||||
|
|
||||||
|
{
|
||||||
|
final byte[] ephemeralKeyMessageBytes = new byte[32];
|
||||||
|
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
|
||||||
|
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
|
||||||
|
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
|
||||||
|
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
|
||||||
|
|
||||||
|
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
|
||||||
|
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
|
||||||
|
|
||||||
|
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
|
||||||
|
|
||||||
|
final byte[] serverPublicKeySignature = new byte[64];
|
||||||
|
|
||||||
|
final int payloadLength =
|
||||||
|
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
|
||||||
|
|
||||||
|
assertEquals(serverPublicKeySignature.length, payloadLength);
|
||||||
|
|
||||||
|
final byte[] serverPublicKey = new byte[32];
|
||||||
|
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||||
|
|
||||||
|
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientHandshakeState;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId)
|
||||||
|
throws ShortBufferException, InterruptedException {
|
||||||
|
|
||||||
|
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
|
||||||
|
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
|
||||||
|
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
|
||||||
|
clientIdentityPayloadBuffer.put(deviceId);
|
||||||
|
clientIdentityPayloadBuffer.flip();
|
||||||
|
|
||||||
|
final byte[] clientStaticKeyMessageBytes = new byte[81];
|
||||||
|
final int messageLength =
|
||||||
|
handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining());
|
||||||
|
|
||||||
|
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
|
||||||
|
|
||||||
|
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
|
||||||
|
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
|
||||||
|
|
||||||
|
assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess());
|
||||||
|
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||||
|
import io.netty.channel.ChannelPromise;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
|
||||||
|
class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final WebSocketCloseListener webSocketCloseListener;
|
||||||
|
|
||||||
|
OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
|
||||||
|
this.webSocketCloseListener = webSocketCloseListener;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
||||||
|
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||||
|
webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
super.write(context, message, promise);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import java.util.List;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler());
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource
|
||||||
|
void allowWebSocketFrame(final WebSocketFrame frame) {
|
||||||
|
embeddedChannel.writeOneInbound(frame);
|
||||||
|
|
||||||
|
try {
|
||||||
|
assertEquals(frame, embeddedChannel.inboundMessages().poll());
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
assertEquals(1, frame.refCnt());
|
||||||
|
} finally {
|
||||||
|
frame.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<WebSocketFrame> allowWebSocketFrame() {
|
||||||
|
return List.of(
|
||||||
|
new BinaryWebSocketFrame(),
|
||||||
|
new CloseWebSocketFrame(),
|
||||||
|
new ContinuationWebSocketFrame(),
|
||||||
|
new PingWebSocketFrame(),
|
||||||
|
new PongWebSocketFrame());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectTextFrame() {
|
||||||
|
final TextWebSocketFrame textFrame = new TextWebSocketFrame();
|
||||||
|
embeddedChannel.writeOneInbound(textFrame);
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
assertEquals(0, textFrame.refCnt());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectNonWebSocketFrame() {
|
||||||
|
final ByteBuf bytes = Unpooled.buffer(0);
|
||||||
|
embeddedChannel.writeOneInbound(bytes);
|
||||||
|
|
||||||
|
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||||
|
assertEquals(0, bytes.refCnt());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
interface WebSocketCloseListener {
|
||||||
|
|
||||||
|
WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() {
|
||||||
|
@Override
|
||||||
|
public void handleWebSocketClosedByClient(final int statusCode) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleWebSocketClosedByServer(final int statusCode) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void handleWebSocketClosedByClient(int statusCode);
|
||||||
|
|
||||||
|
void handleWebSocketClosedByServer(int statusCode);
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
|
import io.netty.channel.Channel;
|
||||||
|
import io.netty.channel.ChannelInitializer;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.channel.local.LocalServerChannel;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import java.net.SocketAddress;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.util.UUID;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
|
||||||
|
class WebSocketNoiseTunnelClient implements AutoCloseable {
|
||||||
|
|
||||||
|
private final ServerBootstrap serverBootstrap;
|
||||||
|
private Channel serverChannel;
|
||||||
|
|
||||||
|
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
|
||||||
|
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
|
||||||
|
|
||||||
|
public WebSocketNoiseTunnelClient(final SocketAddress remoteServerAddress,
|
||||||
|
final URI websocketUri,
|
||||||
|
final boolean authenticated,
|
||||||
|
final ECKeyPair ecKeyPair,
|
||||||
|
final ECPublicKey rootPublicKey,
|
||||||
|
@Nullable final UUID accountIdentifier,
|
||||||
|
final byte deviceId,
|
||||||
|
final X509Certificate trustedServerCertificate,
|
||||||
|
final NioEventLoopGroup eventLoopGroup,
|
||||||
|
final WebSocketCloseListener webSocketCloseListener) {
|
||||||
|
|
||||||
|
this.serverBootstrap = new ServerBootstrap()
|
||||||
|
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||||
|
.channel(LocalServerChannel.class)
|
||||||
|
.group(eventLoopGroup)
|
||||||
|
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||||
|
@Override
|
||||||
|
protected void initChannel(final LocalChannel localChannel) {
|
||||||
|
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate,
|
||||||
|
websocketUri,
|
||||||
|
authenticated,
|
||||||
|
ecKeyPair,
|
||||||
|
rootPublicKey,
|
||||||
|
accountIdentifier,
|
||||||
|
deviceId,
|
||||||
|
remoteServerAddress,
|
||||||
|
webSocketCloseListener));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
LocalAddress getLocalAddress() {
|
||||||
|
return (LocalAddress) serverChannel.localAddress();
|
||||||
|
}
|
||||||
|
|
||||||
|
WebSocketNoiseTunnelClient start() throws InterruptedException {
|
||||||
|
serverChannel = serverBootstrap.bind().await().channel();
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws InterruptedException {
|
||||||
|
serverChannel.close().await();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,486 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import io.grpc.ManagedChannel;
|
||||||
|
import io.grpc.ServerBuilder;
|
||||||
|
import io.grpc.Status;
|
||||||
|
import io.grpc.netty.NettyChannelBuilder;
|
||||||
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
import io.netty.channel.nio.NioEventLoopGroup;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.security.KeyFactory;
|
||||||
|
import java.security.KeyStore;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.security.PrivateKey;
|
||||||
|
import java.security.SecureRandom;
|
||||||
|
import java.security.cert.CertificateException;
|
||||||
|
import java.security.cert.CertificateFactory;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.security.spec.InvalidKeySpecException;
|
||||||
|
import java.security.spec.PKCS8EncodedKeySpec;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.TrustManagerFactory;
|
||||||
|
import org.junit.jupiter.api.AfterAll;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.signal.chat.rpc.AuthenticationTypeGrpc;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedRequest;
|
||||||
|
import org.signal.chat.rpc.GetAuthenticatedResponse;
|
||||||
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
|
||||||
|
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private static NioEventLoopGroup nioEventLoopGroup;
|
||||||
|
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
||||||
|
private static ExecutorService delegatedTaskExecutor;
|
||||||
|
|
||||||
|
private static X509Certificate serverTlsCertificate;
|
||||||
|
|
||||||
|
private ClientPublicKeysManager clientPublicKeysManager;
|
||||||
|
|
||||||
|
private ECKeyPair rootKeyPair;
|
||||||
|
private ECKeyPair clientKeyPair;
|
||||||
|
|
||||||
|
private ManagedLocalGrpcServer authenticatedGrpcServer;
|
||||||
|
private ManagedLocalGrpcServer anonymousGrpcServer;
|
||||||
|
|
||||||
|
private WebsocketNoiseTunnelServer websocketNoiseTunnelServer;
|
||||||
|
|
||||||
|
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
|
||||||
|
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
||||||
|
|
||||||
|
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
|
||||||
|
// They were generated with:
|
||||||
|
//
|
||||||
|
// ```shell
|
||||||
|
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
|
||||||
|
// ```
|
||||||
|
private static final String SERVER_CERTIFICATE = """
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
|
||||||
|
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
|
||||||
|
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
|
||||||
|
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
|
||||||
|
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
|
||||||
|
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
|
||||||
|
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
|
||||||
|
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
|
||||||
|
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
|
||||||
|
iOr9sHiO8Rn2u0xRKgU5Ig==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
""";
|
||||||
|
|
||||||
|
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
|
||||||
|
private static final String SERVER_PRIVATE_KEY = """
|
||||||
|
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
|
||||||
|
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
|
||||||
|
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
|
||||||
|
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
|
||||||
|
""";
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void setUpBeforeAll() throws CertificateException {
|
||||||
|
nioEventLoopGroup = new NioEventLoopGroup();
|
||||||
|
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
||||||
|
delegatedTaskExecutor = Executors.newSingleThreadExecutor();
|
||||||
|
|
||||||
|
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
||||||
|
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
||||||
|
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException {
|
||||||
|
|
||||||
|
final PrivateKey serverTlsPrivateKey;
|
||||||
|
{
|
||||||
|
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
|
||||||
|
serverTlsPrivateKey =
|
||||||
|
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeyPair = Curve.generateKeyPair();
|
||||||
|
clientKeyPair = Curve.generateKeyPair();
|
||||||
|
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
|
||||||
|
|
||||||
|
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||||
|
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||||
|
|
||||||
|
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
|
||||||
|
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
|
||||||
|
|
||||||
|
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
serverBuilder.addService(new AuthenticationTypeService(true));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
authenticatedGrpcServer.start();
|
||||||
|
|
||||||
|
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
|
@Override
|
||||||
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
|
serverBuilder.addService(new AuthenticationTypeService(false));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
anonymousGrpcServer.start();
|
||||||
|
|
||||||
|
websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
|
||||||
|
new X509Certificate[] { serverTlsCertificate },
|
||||||
|
serverTlsPrivateKey,
|
||||||
|
nioEventLoopGroup,
|
||||||
|
delegatedTaskExecutor,
|
||||||
|
clientPublicKeysManager,
|
||||||
|
serverKeyPair,
|
||||||
|
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
|
||||||
|
authenticatedGrpcServerAddress,
|
||||||
|
anonymousGrpcServerAddress);
|
||||||
|
|
||||||
|
websocketNoiseTunnelServer.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws InterruptedException {
|
||||||
|
websocketNoiseTunnelServer.stop();
|
||||||
|
authenticatedGrpcServer.stop();
|
||||||
|
anonymousGrpcServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterAll
|
||||||
|
static void tearDownAfterAll() throws InterruptedException {
|
||||||
|
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||||
|
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||||
|
|
||||||
|
delegatedTaskExecutor.shutdown();
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticated() throws InterruptedException {
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertTrue(response.getAuthenticated());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
|
||||||
|
// Try to verify the server's public key with something other than the key with which it was signed
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||||
|
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
|
||||||
|
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||||
|
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||||
|
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||||
|
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAuthenticatedToAnonymousService() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
|
||||||
|
websocketNoiseTunnelServer.getLocalAddress(),
|
||||||
|
URI.create("wss://localhost/anonymous"),
|
||||||
|
true,
|
||||||
|
clientKeyPair,
|
||||||
|
rootKeyPair.getPublicKey(),
|
||||||
|
ACCOUNT_IDENTIFIER,
|
||||||
|
DEVICE_ID,
|
||||||
|
serverTlsCertificate,
|
||||||
|
nioEventLoopGroup,
|
||||||
|
webSocketCloseListener)
|
||||||
|
.start()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymous() throws InterruptedException {
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAnonymousClient()) {
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertFalse(response.getAuthenticated());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymousBadServerKeySignature() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
// Try to verify the server's public key with something other than the key with which it was signed
|
||||||
|
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
|
||||||
|
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void connectAnonymousToAuthenticatedService() throws InterruptedException {
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
|
||||||
|
|
||||||
|
try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
|
||||||
|
websocketNoiseTunnelServer.getLocalAddress(),
|
||||||
|
URI.create("wss://localhost/authenticated"),
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
rootKeyPair.getPublicKey(),
|
||||||
|
null,
|
||||||
|
(byte) 0,
|
||||||
|
serverTlsCertificate,
|
||||||
|
nioEventLoopGroup,
|
||||||
|
webSocketCloseListener)
|
||||||
|
.start()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(websocketNoiseTunnelClient.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||||
|
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
|
||||||
|
return NettyChannelBuilder.forAddress(localAddress)
|
||||||
|
.channelType(LocalChannel.class)
|
||||||
|
.eventLoopGroup(defaultEventLoopGroup)
|
||||||
|
.usePlaintext()
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectIllegalRequests() throws Exception {
|
||||||
|
|
||||||
|
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||||
|
keyStore.load(null, null);
|
||||||
|
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
|
||||||
|
|
||||||
|
final TrustManagerFactory trustManagerFactory =
|
||||||
|
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||||
|
|
||||||
|
trustManagerFactory.init(keyStore);
|
||||||
|
|
||||||
|
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||||
|
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
|
||||||
|
|
||||||
|
final URI authenticatedUri =
|
||||||
|
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
|
||||||
|
|
||||||
|
final URI incorrectUri =
|
||||||
|
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
|
||||||
|
|
||||||
|
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
|
||||||
|
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.uri(authenticatedUri)
|
||||||
|
.PUT(HttpRequest.BodyPublishers.ofString("test"))
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"Non-GET requests should not be allowed");
|
||||||
|
|
||||||
|
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.GET()
|
||||||
|
.uri(authenticatedUri)
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"GET requests without upgrade headers should not be allowed");
|
||||||
|
|
||||||
|
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
|
||||||
|
.GET()
|
||||||
|
.uri(incorrectUri)
|
||||||
|
.build(),
|
||||||
|
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||||
|
"GET requests to unrecognized URIs should not be allowed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
|
||||||
|
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
|
||||||
|
}
|
||||||
|
|
||||||
|
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
|
||||||
|
throws InterruptedException {
|
||||||
|
|
||||||
|
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey());
|
||||||
|
}
|
||||||
|
|
||||||
|
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
|
||||||
|
final ECPublicKey rootPublicKey) throws InterruptedException {
|
||||||
|
|
||||||
|
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
|
||||||
|
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
|
||||||
|
true,
|
||||||
|
clientKeyPair,
|
||||||
|
rootPublicKey,
|
||||||
|
ACCOUNT_IDENTIFIER,
|
||||||
|
DEVICE_ID,
|
||||||
|
serverTlsCertificate,
|
||||||
|
nioEventLoopGroup,
|
||||||
|
webSocketCloseListener)
|
||||||
|
.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
|
||||||
|
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey());
|
||||||
|
}
|
||||||
|
|
||||||
|
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
|
||||||
|
final ECPublicKey rootPublicKey) throws InterruptedException {
|
||||||
|
|
||||||
|
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
|
||||||
|
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
rootPublicKey,
|
||||||
|
null,
|
||||||
|
(byte) 0,
|
||||||
|
serverTlsCertificate,
|
||||||
|
nioEventLoopGroup,
|
||||||
|
webSocketCloseListener)
|
||||||
|
.start();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,104 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.DefaultFullHttpRequest;
|
||||||
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.FullHttpRequest;
|
||||||
|
import io.netty.handler.codec.http.FullHttpResponse;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaderValues;
|
||||||
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.HttpMethod;
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
|
import io.netty.handler.codec.http.HttpVersion;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
|
||||||
|
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
private static final String AUTHENTICATED_PATH = "/authenticated";
|
||||||
|
private static final String ANONYMOUS_PATH = "/anonymous";
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
embeddedChannel = new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH));
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||||
|
void handleValidRequest(final String path) {
|
||||||
|
final FullHttpRequest request = buildRequest(HttpMethod.GET, path,
|
||||||
|
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||||
|
|
||||||
|
try {
|
||||||
|
embeddedChannel.writeOneInbound(request);
|
||||||
|
|
||||||
|
assertEquals(1, request.refCnt());
|
||||||
|
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||||
|
assertEquals(request, embeddedChannel.inboundMessages().poll());
|
||||||
|
} finally {
|
||||||
|
request.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||||
|
void handleUpgradeRequired(final String path) {
|
||||||
|
final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders());
|
||||||
|
|
||||||
|
embeddedChannel.writeOneInbound(request);
|
||||||
|
|
||||||
|
assertEquals(0, request.refCnt());
|
||||||
|
assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleBadPath() {
|
||||||
|
final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect",
|
||||||
|
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||||
|
|
||||||
|
embeddedChannel.writeOneInbound(request);
|
||||||
|
|
||||||
|
assertEquals(0, request.refCnt());
|
||||||
|
assertHttpResponse(HttpResponseStatus.NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||||
|
void handleMethodNotAllowed(final String path) {
|
||||||
|
final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path,
|
||||||
|
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||||
|
|
||||||
|
embeddedChannel.writeOneInbound(request);
|
||||||
|
|
||||||
|
assertEquals(0, request.refCnt());
|
||||||
|
assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertHttpResponse(final HttpResponseStatus expectedStatus) {
|
||||||
|
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||||
|
|
||||||
|
final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll());
|
||||||
|
|
||||||
|
//noinspection DataFlowIssue
|
||||||
|
assertEquals(expectedStatus, response.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) {
|
||||||
|
return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
Unpooled.buffer(0),
|
||||||
|
headers,
|
||||||
|
new DefaultHttpHeaders(true));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||||
|
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
|
class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
private UserEventRecordingHandler userEventRecordingHandler;
|
||||||
|
private EmbeddedChannel embeddedChannel;
|
||||||
|
|
||||||
|
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
|
||||||
|
|
||||||
|
private final List<Object> receivedEvents = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||||
|
receivedEvents.add(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Object> getReceivedEvents() {
|
||||||
|
return receivedEvents;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||||
|
|
||||||
|
embeddedChannel = new EmbeddedChannel(
|
||||||
|
new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
|
||||||
|
userEventRecordingHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource
|
||||||
|
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
|
||||||
|
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||||
|
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||||
|
|
||||||
|
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||||
|
|
||||||
|
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
|
||||||
|
|
||||||
|
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||||
|
return List.of(
|
||||||
|
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
|
||||||
|
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleWebSocketHandshakeCompleteUnexpectedPath() {
|
||||||
|
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||||
|
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
|
||||||
|
|
||||||
|
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||||
|
|
||||||
|
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void handleUnrecognizedEvent() {
|
||||||
|
final Object unrecognizedEvent = new Object();
|
||||||
|
|
||||||
|
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
|
||||||
|
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
option java_multiple_files = true;
|
||||||
|
|
||||||
|
package org.signal.chat.rpc;
|
||||||
|
|
||||||
|
// A simple test service that identifies its authentication type to callers
|
||||||
|
service AuthenticationType {
|
||||||
|
rpc GetAuthenticated (GetAuthenticatedRequest) returns (GetAuthenticatedResponse) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetAuthenticatedRequest {
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetAuthenticatedResponse {
|
||||||
|
bool authenticated = 1;
|
||||||
|
}
|
Loading…
Reference in New Issue