Add a plaintext mode to the Noise-over-WebSocket server for local testing

This commit is contained in:
Jon Chambers 2024-05-21 17:23:22 -04:00 committed by Jon Chambers
parent 9e36cabef0
commit 9a2bfe1180
4 changed files with 106 additions and 34 deletions

View File

@ -23,6 +23,7 @@ import java.net.InetSocketAddress;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -44,8 +45,8 @@ public class WebsocketNoiseTunnelServer implements Managed {
private static final Logger log = LoggerFactory.getLogger(WebsocketNoiseTunnelServer.class); private static final Logger log = LoggerFactory.getLogger(WebsocketNoiseTunnelServer.class);
public WebsocketNoiseTunnelServer(final int websocketPort, public WebsocketNoiseTunnelServer(final int websocketPort,
final X509Certificate[] tlsCertificateChain, @Nullable final X509Certificate[] tlsCertificateChain,
final PrivateKey tlsPrivateKey, @Nullable final PrivateKey tlsPrivateKey,
final NioEventLoopGroup eventLoopGroup, final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor, final Executor delegatedTaskExecutor,
final ClientConnectionManager clientConnectionManager, final ClientConnectionManager clientConnectionManager,
@ -56,22 +57,29 @@ public class WebsocketNoiseTunnelServer implements Managed {
final LocalAddress anonymousGrpcServerAddress, final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws SSLException { final String recognizedProxySecret) throws SSLException {
final SslProvider sslProvider; @Nullable final SslContext sslContext;
if (OpenSsl.isAvailable()) { if (tlsCertificateChain != null && tlsPrivateKey != null) {
log.info("Native OpenSSL provider is available; will use native provider"); final SslProvider sslProvider;
sslProvider = SslProvider.OPENSSL;
if (OpenSsl.isAvailable()) {
log.info("Native OpenSSL provider is available; will use native provider");
sslProvider = SslProvider.OPENSSL;
} else {
log.info("No native SSL provider available; will use JDK provider");
sslProvider = SslProvider.JDK;
}
sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
.clientAuth(ClientAuth.NONE)
.protocols(SslProtocols.TLS_v1_3)
.sslProvider(sslProvider)
.build();
} else { } else {
log.info("No native SSL provider available; will use JDK provider"); log.warn("No TLS credentials provided; Noise-over-WebSocket tunnel will not use TLS. This configuration is not suitable for production environments.");
sslProvider = SslProvider.JDK; sslContext = null;
} }
final SslContext sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
.clientAuth(ClientAuth.NONE)
.protocols(SslProtocols.TLS_v1_3)
.sslProvider(sslProvider)
.build();
this.bootstrap = new ServerBootstrap() this.bootstrap = new ServerBootstrap()
.group(eventLoopGroup) .group(eventLoopGroup)
.channel(NioServerSocketChannel.class) .channel(NioServerSocketChannel.class)
@ -79,8 +87,11 @@ public class WebsocketNoiseTunnelServer implements Managed {
.childHandler(new ChannelInitializer<SocketChannel>() { .childHandler(new ChannelInitializer<SocketChannel>() {
@Override @Override
protected void initChannel(SocketChannel socketChannel) { protected void initChannel(SocketChannel socketChannel) {
if (sslContext != null) {
socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
}
socketChannel.pipeline() socketChannel.pipeline()
.addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor))
.addLast(new HttpServerCodec()) .addLast(new HttpServerCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade // The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade

View File

@ -28,7 +28,8 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey;
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final X509Certificate trustedServerCertificate; private final boolean useTls;
@Nullable private final X509Certificate trustedServerCertificate;
private final URI websocketUri; private final URI websocketUri;
private final boolean authenticated; private final boolean authenticated;
@Nullable private final ECKeyPair ecKeyPair; @Nullable private final ECKeyPair ecKeyPair;
@ -44,7 +45,8 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake"; private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler( EstablishRemoteConnectionHandler(
final X509Certificate trustedServerCertificate, final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
final URI websocketUri, final URI websocketUri,
final boolean authenticated, final boolean authenticated,
@Nullable final ECKeyPair ecKeyPair, @Nullable final ECKeyPair ecKeyPair,
@ -55,6 +57,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final SocketAddress remoteServerAddress, final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) { final WebSocketCloseListener webSocketCloseListener) {
this.useTls = useTls;
this.trustedServerCertificate = trustedServerCertificate; this.trustedServerCertificate = trustedServerCertificate;
this.websocketUri = websocketUri; this.websocketUri = websocketUri;
this.authenticated = authenticated; this.authenticated = authenticated;
@ -75,12 +78,17 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
.handler(new ChannelInitializer<SocketChannel>() { .handler(new ChannelInitializer<SocketChannel>() {
@Override @Override
protected void initChannel(final SocketChannel channel) throws SSLException { protected void initChannel(final SocketChannel channel) throws SSLException {
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (trustedServerCertificate != null) {
sslContextBuilder.trustManager(trustedServerCertificate);
}
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
}
channel.pipeline() channel.pipeline()
.addLast(SslContextBuilder
.forClient()
.trustManager(trustedServerCertificate)
.build()
.newHandler(channel.alloc()))
.addLast(new HttpClientCodec()) .addLast(new HttpClientCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we

View File

@ -32,7 +32,8 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
@Nullable final UUID accountIdentifier, @Nullable final UUID accountIdentifier,
final byte deviceId, final byte deviceId,
final HttpHeaders headers, final HttpHeaders headers,
final X509Certificate trustedServerCertificate, final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
final NioEventLoopGroup eventLoopGroup, final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) { final WebSocketCloseListener webSocketCloseListener) {
@ -43,7 +44,8 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
.childHandler(new ChannelInitializer<LocalChannel>() { .childHandler(new ChannelInitializer<LocalChannel>() {
@Override @Override
protected void initChannel(final LocalChannel localChannel) { protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate, localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(useTls,
trustedServerCertificate,
websocketUri, websocketUri,
authenticated, authenticated,
ecKeyPair, ecKeyPair,

View File

@ -89,7 +89,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private ManagedLocalGrpcServer authenticatedGrpcServer; private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer; private ManagedLocalGrpcServer anonymousGrpcServer;
private WebsocketNoiseTunnelServer websocketNoiseTunnelServer; private WebsocketNoiseTunnelServer tlsWebsocketNoiseTunnelServer;
private WebsocketNoiseTunnelServer plaintextWebsocketNoiseTunnelServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID; private static final byte DEVICE_ID = Device.PRIMARY_ID;
@ -184,7 +185,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer.start(); anonymousGrpcServer.start();
websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0, tlsWebsocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
new X509Certificate[] { serverTlsCertificate }, new X509Certificate[] { serverTlsCertificate },
serverTlsPrivateKey, serverTlsPrivateKey,
nioEventLoopGroup, nioEventLoopGroup,
@ -197,12 +198,28 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServerAddress, anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET); RECOGNIZED_PROXY_SECRET);
websocketNoiseTunnelServer.start(); tlsWebsocketNoiseTunnelServer.start();
plaintextWebsocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
null,
null,
nioEventLoopGroup,
delegatedTaskExecutor,
clientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
plaintextWebsocketNoiseTunnelServer.start();
} }
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
websocketNoiseTunnelServer.stop(); tlsWebsocketNoiseTunnelServer.stop();
plaintextWebsocketNoiseTunnelServer.stop();
authenticatedGrpcServer.stop(); authenticatedGrpcServer.stop();
anonymousGrpcServer.stop(); anonymousGrpcServer.stop();
} }
@ -234,6 +251,36 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
} }
} }
@Test
void connectAuthenticatedPlaintext() throws InterruptedException {
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
nioEventLoopGroup,
WebSocketCloseListener.NOOP_LISTENER)
.start()) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test @Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException { void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
@ -313,7 +360,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient( try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(), tlsWebsocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/anonymous"), URI.create("wss://localhost/anonymous"),
true, true,
clientKeyPair, clientKeyPair,
@ -321,6 +368,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
ACCOUNT_IDENTIFIER, ACCOUNT_IDENTIFIER,
DEVICE_ID, DEVICE_ID,
new DefaultHttpHeaders(), new DefaultHttpHeaders(),
true,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -386,7 +434,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class); final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient( try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(), tlsWebsocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/authenticated"), URI.create("wss://localhost/authenticated"),
false, false,
null, null,
@ -394,6 +442,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null, null,
(byte) 0, (byte) 0,
new DefaultHttpHeaders(), new DefaultHttpHeaders(),
true,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -438,10 +487,10 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri = final URI authenticatedUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null); new URI("https", null, "localhost", tlsWebsocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
final URI incorrectUri = final URI incorrectUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null); new URI("https", null, "localhost", tlsWebsocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) { try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder() assertEquals(405, httpClient.send(HttpRequest.newBuilder()
@ -561,7 +610,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ECPublicKey rootPublicKey, final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException { final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), return new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI, WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true, true,
clientKeyPair, clientKeyPair,
@ -569,6 +618,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
ACCOUNT_IDENTIFIER, ACCOUNT_IDENTIFIER,
DEVICE_ID, DEVICE_ID,
headers, headers,
true,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)
@ -583,7 +633,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ECPublicKey rootPublicKey, final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException { final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(), return new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI, WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
false, false,
null, null,
@ -591,6 +641,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null, null,
(byte) 0, (byte) 0,
headers, headers,
true,
serverTlsCertificate, serverTlsCertificate,
nioEventLoopGroup, nioEventLoopGroup,
webSocketCloseListener) webSocketCloseListener)