Close remote connections only after all active server calls have completed
This commit is contained in:
parent
bb8ce6d981
commit
f191c68efc
|
@ -0,0 +1,55 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.grpc;
|
||||||
|
|
||||||
|
import io.grpc.Context;
|
||||||
|
import io.grpc.ForwardingServerCallListener;
|
||||||
|
import io.grpc.Grpc;
|
||||||
|
import io.grpc.Metadata;
|
||||||
|
import io.grpc.ServerCall;
|
||||||
|
import io.grpc.ServerCallHandler;
|
||||||
|
import io.grpc.ServerInterceptor;
|
||||||
|
import io.grpc.Status;
|
||||||
|
import io.netty.channel.local.LocalAddress;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
|
||||||
|
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
|
||||||
|
*/
|
||||||
|
public class ChannelShutdownInterceptor implements ServerInterceptor {
|
||||||
|
|
||||||
|
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||||
|
|
||||||
|
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||||
|
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||||
|
final Metadata headers,
|
||||||
|
final ServerCallHandler<ReqT, RespT> next) {
|
||||||
|
|
||||||
|
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
|
||||||
|
// Don't allow new calls if the connection is getting ready to close
|
||||||
|
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
|
||||||
|
@Override
|
||||||
|
public void onComplete() {
|
||||||
|
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||||
|
super.onComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onCancel() {
|
||||||
|
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||||
|
super.onCancel();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||||
|
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
||||||
|
@ -58,6 +59,10 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
|
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
|
||||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
|
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
||||||
|
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -107,6 +112,39 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
return requestAttributes;
|
return requestAttributes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
|
||||||
|
* given server call.
|
||||||
|
*
|
||||||
|
* @param serverCall the server call to start
|
||||||
|
*
|
||||||
|
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
|
||||||
|
* underlying channel is closing
|
||||||
|
*/
|
||||||
|
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
|
||||||
|
try {
|
||||||
|
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
|
||||||
|
} catch (final ChannelNotFoundException e) {
|
||||||
|
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
|
||||||
|
// should certainly not proceed.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
|
||||||
|
* associated with the given server call.
|
||||||
|
*
|
||||||
|
* @param serverCall the server call to complete
|
||||||
|
*/
|
||||||
|
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
|
||||||
|
try {
|
||||||
|
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
|
||||||
|
} catch (final ChannelNotFoundException ignored) {
|
||||||
|
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Closes any client connections to this host associated with the given authenticated device.
|
* Closes any client connections to this host associated with the given authenticated device.
|
||||||
*
|
*
|
||||||
|
@ -119,10 +157,13 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
final List<Channel> channelsToClose =
|
final List<Channel> channelsToClose =
|
||||||
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
|
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
|
||||||
|
|
||||||
channelsToClose.forEach(channel ->
|
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
|
||||||
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
}
|
||||||
.toWebSocketCloseStatus("Reauthentication required")))
|
|
||||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
private static void closeRemoteChannel(final Channel channel) {
|
||||||
|
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
|
||||||
|
.toWebSocketCloseStatus("Reauthentication required")))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
|
@ -200,6 +241,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
|
||||||
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
||||||
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
||||||
|
|
||||||
|
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
|
||||||
|
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
|
||||||
|
|
||||||
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
||||||
|
|
||||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.grpc;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import io.grpc.Metadata;
|
||||||
|
import io.grpc.ServerCall;
|
||||||
|
import io.grpc.ServerCallHandler;
|
||||||
|
import io.grpc.Status;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||||
|
|
||||||
|
class ChannelShutdownInterceptorTest {
|
||||||
|
|
||||||
|
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||||
|
private ChannelShutdownInterceptor channelShutdownInterceptor;
|
||||||
|
|
||||||
|
private ServerCallHandler<String, String> nextCallHandler;
|
||||||
|
|
||||||
|
private static final Metadata HEADERS = new Metadata();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
|
||||||
|
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
|
||||||
|
|
||||||
|
//noinspection unchecked
|
||||||
|
nextCallHandler = mock(ServerCallHandler.class);
|
||||||
|
|
||||||
|
//noinspection unchecked
|
||||||
|
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void interceptCallComplete() {
|
||||||
|
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||||
|
|
||||||
|
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
|
||||||
|
|
||||||
|
final ServerCall.Listener<String> serverCallListener =
|
||||||
|
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||||
|
|
||||||
|
serverCallListener.onComplete();
|
||||||
|
|
||||||
|
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||||
|
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
|
||||||
|
verify(serverCall, never()).close(any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void interceptCallCancelled() {
|
||||||
|
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||||
|
|
||||||
|
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
|
||||||
|
|
||||||
|
final ServerCall.Listener<String> serverCallListener =
|
||||||
|
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||||
|
|
||||||
|
serverCallListener.onCancel();
|
||||||
|
|
||||||
|
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||||
|
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
|
||||||
|
verify(serverCall, never()).close(any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void interceptCallChannelClosing() {
|
||||||
|
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||||
|
|
||||||
|
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
|
||||||
|
|
||||||
|
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||||
|
|
||||||
|
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||||
|
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
|
||||||
|
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,14 +12,38 @@ import org.signal.chat.rpc.EchoServiceGrpc;
|
||||||
|
|
||||||
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
|
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
|
||||||
@Override
|
@Override
|
||||||
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
|
public void echo(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
|
||||||
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
|
responseObserver.onNext(buildResponse(echoRequest));
|
||||||
responseObserver.onCompleted();
|
responseObserver.onCompleted();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
|
public void echo2(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
|
||||||
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
|
responseObserver.onNext(buildResponse(echoRequest));
|
||||||
responseObserver.onCompleted();
|
responseObserver.onCompleted();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StreamObserver<EchoRequest> echoStream(final StreamObserver<EchoResponse> responseObserver) {
|
||||||
|
return new StreamObserver<>() {
|
||||||
|
@Override
|
||||||
|
public void onNext(final EchoRequest echoRequest) {
|
||||||
|
responseObserver.onNext(buildResponse(echoRequest));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(final Throwable throwable) {
|
||||||
|
responseObserver.onError(throwable);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onCompleted() {
|
||||||
|
responseObserver.onCompleted();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EchoResponse buildResponse(final EchoRequest echoRequest) {
|
||||||
|
return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package org.whispersystems.textsecuregcm.grpc.net;
|
package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyByte;
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
|
@ -8,10 +9,12 @@ import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
import io.grpc.ManagedChannel;
|
import io.grpc.ManagedChannel;
|
||||||
import io.grpc.ServerBuilder;
|
import io.grpc.ServerBuilder;
|
||||||
import io.grpc.Status;
|
import io.grpc.Status;
|
||||||
import io.grpc.netty.NettyChannelBuilder;
|
import io.grpc.netty.NettyChannelBuilder;
|
||||||
|
import io.grpc.stub.StreamObserver;
|
||||||
import io.netty.channel.DefaultEventLoopGroup;
|
import io.netty.channel.DefaultEventLoopGroup;
|
||||||
import io.netty.channel.local.LocalAddress;
|
import io.netty.channel.local.LocalAddress;
|
||||||
import io.netty.channel.local.LocalChannel;
|
import io.netty.channel.local.LocalChannel;
|
||||||
|
@ -61,6 +64,9 @@ import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.signal.chat.rpc.EchoRequest;
|
||||||
|
import org.signal.chat.rpc.EchoResponse;
|
||||||
|
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||||
|
@ -71,6 +77,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
|
||||||
|
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||||
|
@ -83,6 +91,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
private static NioEventLoopGroup nioEventLoopGroup;
|
private static NioEventLoopGroup nioEventLoopGroup;
|
||||||
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
||||||
private static ExecutorService delegatedTaskExecutor;
|
private static ExecutorService delegatedTaskExecutor;
|
||||||
|
private static ExecutorService serverCallExecutor;
|
||||||
|
|
||||||
private static X509Certificate serverTlsCertificate;
|
private static X509Certificate serverTlsCertificate;
|
||||||
|
|
||||||
|
@ -136,7 +145,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
static void setUpBeforeAll() throws CertificateException {
|
static void setUpBeforeAll() throws CertificateException {
|
||||||
nioEventLoopGroup = new NioEventLoopGroup();
|
nioEventLoopGroup = new NioEventLoopGroup();
|
||||||
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
||||||
delegatedTaskExecutor = Executors.newSingleThreadExecutor();
|
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
|
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
|
|
||||||
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
||||||
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
||||||
|
@ -171,7 +181,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
@Override
|
@Override
|
||||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
serverBuilder.addService(new RequestAttributesServiceImpl())
|
serverBuilder
|
||||||
|
.executor(serverCallExecutor)
|
||||||
|
.addService(new RequestAttributesServiceImpl())
|
||||||
|
.addService(new EchoServiceImpl())
|
||||||
|
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
|
||||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||||
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
|
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
|
||||||
}
|
}
|
||||||
|
@ -182,7 +196,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
||||||
@Override
|
@Override
|
||||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||||
serverBuilder.addService(new RequestAttributesServiceImpl())
|
serverBuilder
|
||||||
|
.executor(serverCallExecutor)
|
||||||
|
.addService(new RequestAttributesServiceImpl())
|
||||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||||
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
|
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
|
||||||
}
|
}
|
||||||
|
@ -195,7 +211,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
serverTlsPrivateKey,
|
serverTlsPrivateKey,
|
||||||
nioEventLoopGroup,
|
nioEventLoopGroup,
|
||||||
delegatedTaskExecutor,
|
delegatedTaskExecutor,
|
||||||
grpcClientConnectionManager,
|
grpcClientConnectionManager,
|
||||||
clientPublicKeysManager,
|
clientPublicKeysManager,
|
||||||
serverKeyPair,
|
serverKeyPair,
|
||||||
authenticatedGrpcServerAddress,
|
authenticatedGrpcServerAddress,
|
||||||
|
@ -209,7 +225,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
null,
|
null,
|
||||||
nioEventLoopGroup,
|
nioEventLoopGroup,
|
||||||
delegatedTaskExecutor,
|
delegatedTaskExecutor,
|
||||||
grpcClientConnectionManager,
|
grpcClientConnectionManager,
|
||||||
clientPublicKeysManager,
|
clientPublicKeysManager,
|
||||||
serverKeyPair,
|
serverKeyPair,
|
||||||
authenticatedGrpcServerAddress,
|
authenticatedGrpcServerAddress,
|
||||||
|
@ -235,6 +251,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
delegatedTaskExecutor.shutdown();
|
delegatedTaskExecutor.shutdown();
|
||||||
//noinspection ResultOfMethodCallIgnored
|
//noinspection ResultOfMethodCallIgnored
|
||||||
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
serverCallExecutor.shutdown();
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -579,6 +599,89 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void waitForCallCompletion() throws InterruptedException {
|
||||||
|
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
|
||||||
|
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
|
||||||
|
final AtomicBoolean closedByServer = new AtomicBoolean(false);
|
||||||
|
|
||||||
|
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleWebSocketClosedByClient(final int statusCode) {
|
||||||
|
serverCloseStatusCode.set(statusCode);
|
||||||
|
closedByServer.set(false);
|
||||||
|
connectionCloseLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleWebSocketClosedByServer(final int statusCode) {
|
||||||
|
serverCloseStatusCode.set(statusCode);
|
||||||
|
closedByServer.set(true);
|
||||||
|
connectionCloseLatch.countDown();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
try (final NoiseWebSocketTunnelClient client = authenticated()
|
||||||
|
.setWebSocketCloseListener(webSocketCloseListener)
|
||||||
|
.build()) {
|
||||||
|
|
||||||
|
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||||
|
|
||||||
|
try {
|
||||||
|
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||||
|
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||||
|
|
||||||
|
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||||
|
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||||
|
|
||||||
|
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
|
||||||
|
|
||||||
|
// Start an open-ended server call and leave it in a non-complete state
|
||||||
|
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
|
||||||
|
new StreamObserver<>() {
|
||||||
|
@Override
|
||||||
|
public void onNext(final EchoResponse echoResponse) {
|
||||||
|
responseCountDownLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(final Throwable throwable) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onCompleted() {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
|
||||||
|
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
|
||||||
|
// truly started before requesting connection closure.
|
||||||
|
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
|
||||||
|
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
|
||||||
|
|
||||||
|
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||||
|
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
|
||||||
|
"Channel should not close until active requests have finished");
|
||||||
|
|
||||||
|
//noinspection ResultOfMethodCallIgnored
|
||||||
|
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
|
||||||
|
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
|
||||||
|
|
||||||
|
// Complete the open-ended server call
|
||||||
|
echoRequestStreamObserver.onCompleted();
|
||||||
|
|
||||||
|
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
|
||||||
|
"Channel should close once active requests have finished");
|
||||||
|
|
||||||
|
assertTrue(closedByServer.get());
|
||||||
|
assertEquals(4004, serverCloseStatusCode.get());
|
||||||
|
} finally {
|
||||||
|
channel.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private NoiseWebSocketTunnelClient.Builder anonymous() {
|
private NoiseWebSocketTunnelClient.Builder anonymous() {
|
||||||
return new NoiseWebSocketTunnelClient
|
return new NoiseWebSocketTunnelClient
|
||||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||||
|
|
|
@ -13,6 +13,7 @@ package org.signal.chat.rpc;
|
||||||
service EchoService {
|
service EchoService {
|
||||||
rpc echo (EchoRequest) returns (EchoResponse) {}
|
rpc echo (EchoRequest) returns (EchoResponse) {}
|
||||||
rpc echo2 (EchoRequest) returns (EchoResponse) {}
|
rpc echo2 (EchoRequest) returns (EchoResponse) {}
|
||||||
|
rpc echoStream (stream EchoRequest) returns (stream EchoResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
message EchoRequest {
|
message EchoRequest {
|
||||||
|
|
Loading…
Reference in New Issue