diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java new file mode 100644 index 000000000..24cec57ed --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java @@ -0,0 +1,55 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Context; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.netty.channel.local.LocalAddress; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; + +/** + * Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with + * {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise. + */ +public class ChannelShutdownInterceptor implements ServerInterceptor { + + private final GrpcClientConnectionManager grpcClientConnectionManager; + + public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { + this.grpcClientConnectionManager = grpcClientConnectionManager; + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + + if (!grpcClientConnectionManager.handleServerCallStart(call)) { + // Don't allow new calls if the connection is getting ready to close + return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); + } + + return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) { + @Override + public void onComplete() { + grpcClientConnectionManager.handleServerCallComplete(call); + super.onComplete(); + } + + @Override + public void onCancel() { + grpcClientConnectionManager.handleServerCallComplete(call); + super.onCancel(); + } + }; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java index f3eec095c..be7ddc477 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java @@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import org.whispersystems.textsecuregcm.grpc.RequestAttributes; +import org.whispersystems.textsecuregcm.util.ClosableEpoch; /** * A client connection manager associates a local connection to a local gRPC server with a remote connection through a @@ -58,6 +59,10 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener public static final AttributeKey REQUEST_ATTRIBUTES_KEY = AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes"); + @VisibleForTesting + static final AttributeKey EPOCH_ATTRIBUTE_KEY = + AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch"); + private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); /** @@ -107,6 +112,39 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener return requestAttributes; } + /** + * Handles the start of a server call, incrementing the active call count for the remote channel associated with the + * given server call. + * + * @param serverCall the server call to start + * + * @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the + * underlying channel is closing + */ + public boolean handleServerCallStart(final ServerCall serverCall) { + try { + return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive(); + } catch (final ChannelNotFoundException e) { + // This would only happen if the channel had already closed, which is certainly possible. In this case, the call + // should certainly not proceed. + return false; + } + } + + /** + * Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel + * associated with the given server call. + * + * @param serverCall the server call to complete + */ + public void handleServerCallComplete(final ServerCall serverCall) { + try { + getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart(); + } catch (final ChannelNotFoundException ignored) { + // In practice, we'd only get here if the channel has already closed, so we can just ignore the exception + } + } + /** * Closes any client connections to this host associated with the given authenticated device. * @@ -119,10 +157,13 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener final List channelsToClose = new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList())); - channelsToClose.forEach(channel -> - channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED - .toWebSocketCloseStatus("Reauthentication required"))) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); + channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close()); + } + + private static void closeRemoteChannel(final Channel channel) { + channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED + .toWebSocketCloseStatus("Reauthentication required"))) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } @VisibleForTesting @@ -200,6 +241,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); + remoteChannel.attr(EPOCH_ATTRIBUTE_KEY) + .set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel))); + remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java new file mode 100644 index 000000000..35db29a77 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.Status; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; + +class ChannelShutdownInterceptorTest { + + private GrpcClientConnectionManager grpcClientConnectionManager; + private ChannelShutdownInterceptor channelShutdownInterceptor; + + private ServerCallHandler nextCallHandler; + + private static final Metadata HEADERS = new Metadata(); + + @BeforeEach + void setUp() { + grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); + channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager); + + //noinspection unchecked + nextCallHandler = mock(ServerCallHandler.class); + + //noinspection unchecked + when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class)); + } + + @Test + void interceptCallComplete() { + @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); + + when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true); + + final ServerCall.Listener serverCallListener = + channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); + + serverCallListener.onComplete(); + + verify(grpcClientConnectionManager).handleServerCallStart(serverCall); + verify(grpcClientConnectionManager).handleServerCallComplete(serverCall); + verify(serverCall, never()).close(any(), any()); + } + + @Test + void interceptCallCancelled() { + @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); + + when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true); + + final ServerCall.Listener serverCallListener = + channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); + + serverCallListener.onCancel(); + + verify(grpcClientConnectionManager).handleServerCallStart(serverCall); + verify(grpcClientConnectionManager).handleServerCallComplete(serverCall); + verify(serverCall, never()).close(any(), any()); + } + + @Test + void interceptCallChannelClosing() { + @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); + + when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false); + + channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); + + verify(grpcClientConnectionManager).handleServerCallStart(serverCall); + verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall); + verify(serverCall).close(eq(Status.UNAVAILABLE), any()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/EchoServiceImpl.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/EchoServiceImpl.java index 81714ecb1..bf7eb840f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/EchoServiceImpl.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/EchoServiceImpl.java @@ -12,14 +12,38 @@ import org.signal.chat.rpc.EchoServiceGrpc; public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase { @Override - public void echo(EchoRequest req, StreamObserver responseObserver) { - responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); + public void echo(final EchoRequest echoRequest, final StreamObserver responseObserver) { + responseObserver.onNext(buildResponse(echoRequest)); responseObserver.onCompleted(); } @Override - public void echo2(EchoRequest req, StreamObserver responseObserver) { - responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build()); + public void echo2(final EchoRequest echoRequest, final StreamObserver responseObserver) { + responseObserver.onNext(buildResponse(echoRequest)); responseObserver.onCompleted(); } + + @Override + public StreamObserver echoStream(final StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(final EchoRequest echoRequest) { + responseObserver.onNext(buildResponse(echoRequest)); + } + + @Override + public void onError(final Throwable throwable) { + responseObserver.onError(throwable); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + private static EchoResponse buildResponse(final EchoRequest echoRequest) { + return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build(); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java index 603ab94fa..301443a4a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseWebSocketTunnelServerIntegrationTest.java @@ -1,6 +1,7 @@ package org.whispersystems.textsecuregcm.grpc.net; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; @@ -8,10 +9,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.protobuf.ByteString; import io.grpc.ManagedChannel; import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.StreamObserver; import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; @@ -61,6 +64,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.signal.chat.rpc.EchoRequest; +import org.signal.chat.rpc.EchoResponse; +import org.signal.chat.rpc.EchoServiceGrpc; import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; import org.signal.chat.rpc.GetRequestAttributesRequest; @@ -71,6 +77,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor; +import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; @@ -83,6 +91,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes private static NioEventLoopGroup nioEventLoopGroup; private static DefaultEventLoopGroup defaultEventLoopGroup; private static ExecutorService delegatedTaskExecutor; + private static ExecutorService serverCallExecutor; private static X509Certificate serverTlsCertificate; @@ -136,7 +145,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes static void setUpBeforeAll() throws CertificateException { nioEventLoopGroup = new NioEventLoopGroup(); defaultEventLoopGroup = new DefaultEventLoopGroup(); - delegatedTaskExecutor = Executors.newSingleThreadExecutor(); + delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor(); + serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor(); final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( @@ -171,7 +181,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { @Override protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder.addService(new RequestAttributesServiceImpl()) + serverBuilder + .executor(serverCallExecutor) + .addService(new RequestAttributesServiceImpl()) + .addService(new EchoServiceImpl()) + .intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager)) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); } @@ -182,7 +196,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { @Override protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder.addService(new RequestAttributesServiceImpl()) + serverBuilder + .executor(serverCallExecutor) + .addService(new RequestAttributesServiceImpl()) .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); } @@ -195,7 +211,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes serverTlsPrivateKey, nioEventLoopGroup, delegatedTaskExecutor, - grpcClientConnectionManager, + grpcClientConnectionManager, clientPublicKeysManager, serverKeyPair, authenticatedGrpcServerAddress, @@ -209,7 +225,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes null, nioEventLoopGroup, delegatedTaskExecutor, - grpcClientConnectionManager, + grpcClientConnectionManager, clientPublicKeysManager, serverKeyPair, authenticatedGrpcServerAddress, @@ -235,6 +251,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes delegatedTaskExecutor.shutdown(); //noinspection ResultOfMethodCallIgnored delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); + + serverCallExecutor.shutdown(); + //noinspection ResultOfMethodCallIgnored + serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS); } @ParameterizedTest @@ -579,6 +599,89 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes } } + @Test + void waitForCallCompletion() throws InterruptedException { + final CountDownLatch connectionCloseLatch = new CountDownLatch(1); + final AtomicInteger serverCloseStatusCode = new AtomicInteger(0); + final AtomicBoolean closedByServer = new AtomicBoolean(false); + + final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() { + + @Override + public void handleWebSocketClosedByClient(final int statusCode) { + serverCloseStatusCode.set(statusCode); + closedByServer.set(false); + connectionCloseLatch.countDown(); + } + + @Override + public void handleWebSocketClosedByServer(final int statusCode) { + serverCloseStatusCode.set(statusCode); + closedByServer.set(true); + connectionCloseLatch.countDown(); + } + }; + + try (final NoiseWebSocketTunnelClient client = authenticated() + .setWebSocketCloseListener(webSocketCloseListener) + .build()) { + + final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); + + try { + final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) + .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); + + assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); + assertEquals(DEVICE_ID, response.getDeviceId()); + + final CountDownLatch responseCountDownLatch = new CountDownLatch(1); + + // Start an open-ended server call and leave it in a non-complete state + final StreamObserver echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream( + new StreamObserver<>() { + @Override + public void onNext(final EchoResponse echoResponse) { + responseCountDownLatch.countDown(); + } + + @Override + public void onError(final Throwable throwable) { + } + + @Override + public void onCompleted() { + } + }); + + // Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before + // the request even starts. Make sure we've done at least one request/response pair to ensure that the call has + // truly started before requesting connection closure. + echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()); + assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS)); + + grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); + assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS), + "Channel should not close until active requests have finished"); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel) + .echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build())); + + // Complete the open-ended server call + echoRequestStreamObserver.onCompleted(); + + assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS), + "Channel should close once active requests have finished"); + + assertTrue(closedByServer.get()); + assertEquals(4004, serverCloseStatusCode.get()); + } finally { + channel.shutdown(); + } + } + } + private NoiseWebSocketTunnelClient.Builder anonymous() { return new NoiseWebSocketTunnelClient .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey()) diff --git a/service/src/test/proto/echo_service.proto b/service/src/test/proto/echo_service.proto index d411b6a72..971efe761 100644 --- a/service/src/test/proto/echo_service.proto +++ b/service/src/test/proto/echo_service.proto @@ -13,6 +13,7 @@ package org.signal.chat.rpc; service EchoService { rpc echo (EchoRequest) returns (EchoResponse) {} rpc echo2 (EchoRequest) returns (EchoResponse) {} + rpc echoStream (stream EchoRequest) returns (stream EchoResponse) {} } message EchoRequest {